634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/LICENSE.MIT:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Everydream, Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # Stable Tuner, Fine-tune your SD
8 |
9 |
10 |
11 | [Watch the intro video](https://www.youtube.com/watch?v=36Z4ETFZx94)
12 |
13 | ### [Join the Discord](https://discord.gg/DahNECrBUZ) for training and chill ;)
14 |
15 | Stable Tuner aims to be the easiest and most complete Stable Diffusion tuner :)
16 |
17 | ## Features
18 |
19 | - **For End Users** - ST provides a convenient yet powerful fine-tuning solution on Windows; if you want to try fine-tuning, there's no better option. For Linux folks, a bash script will be added at a later date if there's enough interest.
20 | - **More models, more fun** - ST now supports training Depth2Img and Inpainting models!
21 | - **Your Training, Everywhere** - ST now supports cloud training as well! Package up your data and start training on RunPod/Colab etc. with a few clicks!
22 | - **Easy Installation** - ST makes installing convenient: using a bat file, ST will set up an environment ready for work and install all the necessary components to get your training started fast!
23 | - **Friendly GUI** - ST features a full GUI to configure training runs, import and export settings, view tooltips for options, test your new model in the playground, convert the model to CKPT and more!
24 | - **Better Performance** - Using Diffusers, xFormers, cuDNN 8.6 and bitsandbytes along with latent caching allows for higher batch sizes and faster speeds; higher batch sizes = better quality model!
25 | - **A Toolbox** - Use Caption Buddy to quickly generate and edit captions for your dataset in one streamlined tool; ST is building a toolbox of must-have tools for training models.
26 | - **Fine Tuning Mindset** - Unlike Dreambooth, ST is built to fine-tune a model, providing tools and settings to make the most of your 3090/4090; Dreambooth is still an option.
27 | - **Filename/Caption/Token based learning** - You can train using the individual file names as captions, a caption txt file, or a single token DB-style; for fine-tuning, file names and captions work best.
28 | - **Aspect Ratio Bucketing** - With aspect ratio bucketing you can use any aspect ratio or resolution for your training images; images get shuffled into buckets and resized to your chosen resolution target (see the sketch after this list). Supports up to 1024 resolution!
29 | - **Remote monitoring using Telegram** - Want to keep tabs on your training? Set up a bot in Telegram and receive samples and notifications as you train.
30 | - **Better Sampling controls** - Sampling is important for gauging how your model is doing, so ST gives you the option to add sample prompts as you see fit, set the number of images to produce per prompt, send a controlled-seed prompt (to gauge how a seed changes), or even use random aspect ratios to see how buckets are changing your generations!
31 | - **Better Dataset Handling** - Use dataset balancing to even out multiple concepts so they don't overpower each other, add class images to the dataset to train them directly, and override per dataset if necessary.
32 | - **Quality of life** - Many options to tune the experience to your liking: save latent caches to avoid regenerating them on every run, use high batch sizes to maximize training speed and performance, and use epochs instead of steps to gauge progress better!
33 | - **Built for Diffusers** - ST uses HF's Diffusers library to adopt the best and fastest implementations going forward; as of now, training 1.4, 1.5, 2 and 2-768 works great.
34 |
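As an illustration of how aspect ratio bucketing groups images, here is a minimal, self-contained sketch (the helper names are made up for this example and are not ST's actual implementation):

```python
# Minimal aspect-ratio bucketing sketch (illustrative only).
from collections import defaultdict

# Candidate (width, height) buckets around a 512px target; trainers
# typically generate a list like this automatically.
BUCKETS = [(512, 512), (576, 448), (448, 576), (640, 384), (384, 640)]

def assign_bucket(width: int, height: int) -> tuple[int, int]:
    """Pick the bucket whose aspect ratio is closest to the image's."""
    ratio = width / height
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - ratio))

def bucketize(sizes: list[tuple[int, int]]) -> dict[tuple[int, int], list[int]]:
    """Group image indices by bucket; each training batch is then drawn
    from a single bucket so all tensors in the batch share one shape."""
    groups = defaultdict(list)
    for i, (w, h) in enumerate(sizes):
        groups[assign_bucket(w, h)].append(i)
    return dict(groups)

print(bucketize([(1920, 1080), (512, 512), (800, 1200)]))
```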
35 | ## Installation
36 |
37 | Download and install Anaconda or Miniconda, clone this repo, and run install_stabletuner.bat; when it finishes, start the app with the StableTuner.cmd file.
38 |
39 | Note: If your Anaconda is installed in a non-standard directory, create a text file called custom_conda_path.txt containing the path to your Anaconda installation before running the install_stabletuner.bat file.
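
Illustratively, the conda lookup the launcher performs can be sketched in Python as below (the real logic lives in StableTuner.cmd and install_stabletuner.bat; the function name here is made up):

```python
import os
from pathlib import Path

def find_conda() -> "Path | None":
    """Mirror of the launcher's lookup: prefer the path given in
    custom_conda_path.txt, then fall back to the standard install dirs."""
    candidates = []
    override = Path("custom_conda_path.txt")
    if override.exists():
        candidates.append(Path(override.read_text().strip()))
    for base in (os.environ.get("ProgramData", ""), os.environ.get("USERPROFILE", "")):
        for name in ("miniconda3", "anaconda3"):
            candidates.append(Path(base) / name)
    for cand in candidates:
        if (cand / "Scripts" / "activate.bat").exists():
            return cand
    return None  # the launcher tells the user to install Miniconda in this case

print(find_conda())
```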
40 |
41 | ## CUDNN 8.6
42 |
43 | **NOTICE - As of this writing, this step is no longer necessary; the installer downloads cuDNN by itself. Keeping it here for now.**
44 |
45 | Due to their file size I can't host the DLLs needed for cuDNN 8.6 on GitHub; I **strongly** advise downloading them for a speed boost in sample generation (almost **50%** on a 4090). You can download them from here: CUDNN 8.6
46 |
47 | To install, simply unzip the directory, place it in the same directory as StableTuner.cmd, run install_stabletuner.bat, and you're good to go!
48 |
49 | ## Usage
50 |
51 | Refer to the tooltips in the GUI for more information; if you have any questions, feel free to ask in the Discord.
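
The configs folder ships example run settings (note that most numeric fields are stored as JSON strings). A minimal sketch of reading one, with a made-up helper name:

```python
import json

def load_run_config(path: str) -> dict:
    """Load a configs/*.json file and coerce the string-typed numeric
    fields ("train_epocs" is the key's actual spelling in the files)."""
    with open(path, encoding="utf-8") as f:
        cfg = json.load(f)
    for key in ("batch_size", "train_epocs", "resolution", "warmup_steps"):
        cfg[key] = int(cfg[key])
    cfg["learning_rate"] = float(cfg["learning_rate"])
    return cfg

cfg = load_run_config("configs/4090_SD2_512_finetune.json")
print(cfg["batch_size"], cfg["learning_rate"])
```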
52 |
53 | ## Kudos
54 |
55 | - Shivam - For the original code and inspiration - Apache 2.0 License
56 | - Diffusers - For the latest and greatest implementations - Apache 2.0 License
57 | - Everydream - For the Aspect Ratio bucketing - MIT License
58 | - Sygil.dev - For the environment setup - AGPLv3 License
59 | - sd_dreambooth_extension - For the bitsandbytes files and install script
60 | - StabilityAI - For the latest and greatest models
61 | - The whole SD community - For making this possible
62 |
63 | ## What's next?
64 |
65 | - Linux support
66 | - More models
67 | - Advanced model mixing
68 | - And more! :D
69 | - Support me on Ko-Fi and come hang out in Discord to help me decide what's next :)
70 |
--------------------------------------------------------------------------------
/StableTuner.cmd:
--------------------------------------------------------------------------------
1 | @echo off
2 | :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
3 | ::
4 | :: Copyright 2022 Sygil-Dev team.
5 | :: This program is free software: you can redistribute it and/or modify
6 | :: it under the terms of the GNU Affero General Public License as published by
7 | :: the Free Software Foundation, either version 3 of the License, or
8 | :: (at your option) any later version.
9 | ::
10 | :: This program is distributed in the hope that it will be useful,
11 | :: but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | :: GNU Affero General Public License for more details.
14 | ::
15 | :: You should have received a copy of the GNU Affero General Public License
16 | :: along with this program. If not, see <https://www.gnu.org/licenses/>.
17 | :: Run all commands using this script's directory as the working directory
18 | cd %~dp0
19 |
20 | :: The environment name is hardcoded here and must match the "name:" line in environment.yaml
21 | set v_conda_env_name=ST
22 |
23 |
24 | echo Environment name is set as %v_conda_env_name% as per environment.yaml
25 |
26 | :: Put the path to conda directory in a file called "custom_conda_path.txt" if it's installed at non-standard path
27 | IF EXIST custom_conda_path.txt (
28 | FOR /F %%i IN (custom_conda_path.txt) DO set v_custom_path=%%i
29 | )
30 |
31 | set INSTALL_ENV_DIR=%cd%\installer_files\env
32 | set PATH=%INSTALL_ENV_DIR%;%INSTALL_ENV_DIR%\Library\bin;%INSTALL_ENV_DIR%\Scripts;%INSTALL_ENV_DIR%\Library\usr\bin;%PATH%
33 |
34 | set v_paths=%INSTALL_ENV_DIR%
35 | set v_paths=%v_paths%;%ProgramData%\miniconda3
36 | set v_paths=%v_paths%;%USERPROFILE%\miniconda3
37 | set v_paths=%v_paths%;%ProgramData%\anaconda3
38 | set v_paths=%v_paths%;%USERPROFILE%\anaconda3
39 |
40 | :: Prepend the custom conda path (if any) so it is searched first
41 | IF NOT "%v_custom_path%"=="" (
42 |     set v_paths=%v_custom_path%;%v_paths%
43 | )
45 |
46 | for %%a in (%v_paths%) do (
47 | if EXIST "%%a\Scripts\activate.bat" (
48 | SET v_conda_path=%%a
49 | echo anaconda3/miniconda3 detected in %%a
50 | goto :CONDA_FOUND
51 | )
52 | )
53 |
54 | IF "%v_conda_path%"=="" (
55 | echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html
56 | pause
57 | exit /b 1
58 | )
59 |
60 | :CONDA_FOUND
61 | echo Starting conda environment %v_conda_env_name% from %v_conda_path%
62 |
63 | call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
64 |
65 | ::call git pull | findstr /r /c:"changed" && set "HasChanges=1"
66 | ::IF "%HasChanges%" == "0" GOTO START_GUI
67 | ::echo StableTuner updated, running installer!
68 | ::call conda env create --name "%v_conda_env_name%" -f environment.yaml
69 | ::call conda env update --name "%v_conda_env_name%" -f environment.yaml
70 | ::python windows_install.py
71 | :START_GUI
72 | ::set HasChanges=0
73 | python scripts/configuration_gui.py
74 |
75 | ::cmd /k
76 |
--------------------------------------------------------------------------------
/configs/4090_SD2_512_dreambooth.json:
--------------------------------------------------------------------------------
1 | {
2 | "concepts": [
3 | {
4 | "instance_prompt": "token",
5 | "class_prompt": "class token",
6 | "instance_data_dir": "data path",
7 | "class_data_dir": "class data path",
8 | "do_not_balance": 0
9 | }
10 | ],
11 | "sample_prompts": [],
12 | "add_controlled_seed_to_sample": [
13 | "3434554"
14 | ],
15 | "model_path": "stabilityai/stable-diffusion-2-base",
16 | "vae_path": "stabilityai/sd-vae-ft-mse",
17 | "output_path": "models/art_test",
18 | "send_telegram_updates": 0,
19 | "telegram_token": "",
20 | "telegram_chat_id": "",
21 | "resolution": "512",
22 | "batch_size": "16",
23 | "train_epocs": "30",
24 | "mixed_precision": "fp16",
25 | "use_8bit_adam": 1,
26 | "use_gradient_checkpointing": 1,
27 | "accumulation_steps": "1",
28 | "learning_rate": "5e-6",
29 | "warmup_steps": "0",
30 | "learning_rate_scheduler": "constant",
31 | "use_latent_cache": 1,
32 | "save_latent_cache": 1,
33 | "regenerate_latent_cache": 0,
34 | "train_text_encoder": 1,
35 | "with_prior_loss_preservation": 1,
36 | "prior_loss_preservation_weight": "1.0",
37 | "use_image_names_as_captions": 1,
38 | "auto_balance_concept_datasets": 1,
39 | "add_class_images_to_dataset": 0,
40 | "number_of_class_images": "200",
41 | "save_every_n_epochs": "5",
42 | "number_of_samples_to_generate": "1",
43 | "sample_height": "512",
44 | "sample_width": "512",
45 | "sample_random_aspect_ratio": 0,
46 | "save_on_training_start": 0,
47 | "aspect_ratio_bucketing": 1,
48 | "seed": "12354"
49 | }
--------------------------------------------------------------------------------
/configs/4090_SD2_512_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 | "concepts": [
3 | {
4 | "instance_prompt": "token",
5 | "class_prompt": "class token",
6 | "instance_data_dir": "data path",
7 | "class_data_dir": "class data path",
8 | "do_not_balance": 0
9 | }
10 | ],
11 | "sample_prompts": [],
12 | "add_controlled_seed_to_sample": [
13 | "3434554"
14 | ],
15 | "model_path": "stabilityai/stable-diffusion-2-base",
16 | "vae_path": "stabilityai/sd-vae-ft-mse",
17 | "output_path": "models/art_test",
18 | "send_telegram_updates": 0,
19 | "telegram_token": "",
20 | "telegram_chat_id": "",
21 | "resolution": "512",
22 | "batch_size": "24",
23 | "train_epocs": "200",
24 | "mixed_precision": "fp16",
25 | "use_8bit_adam": 1,
26 | "use_gradient_checkpointing": 1,
27 | "accumulation_steps": "1",
28 | "learning_rate": "1e-6",
29 | "warmup_steps": "0",
30 | "learning_rate_scheduler": "constant",
31 | "use_latent_cache": 1,
32 | "save_latent_cache": 1,
33 | "regenerate_latent_cache": 0,
34 | "train_text_encoder": 1,
35 | "with_prior_loss_preservation": 0,
36 | "prior_loss_preservation_weight": "1.0",
37 | "use_image_names_as_captions": 1,
38 | "auto_balance_concept_datasets": 1,
39 | "add_class_images_to_dataset": 0,
40 | "number_of_class_images": "200",
41 | "save_every_n_epochs": "15",
42 | "number_of_samples_to_generate": "1",
43 | "sample_height": "512",
44 | "sample_width": "512",
45 | "sample_random_aspect_ratio": 0,
46 | "save_on_training_start": 0,
47 | "aspect_ratio_bucketing": 1,
48 | "seed": "12354"
49 | }
--------------------------------------------------------------------------------
/configs/test.json:
--------------------------------------------------------------------------------
1 | {
2 | "concepts": [],
3 | "sample_prompts": [],
4 | "add_controlled_seed_to_sample": [
5 | "3434554"
6 | ],
7 | "model_path": "stabilityai/stable-diffusion-2",
8 | "vae_path": "stabilityai/sd-vae-ft-mse",
9 | "output_path": "models/art_test",
10 | "send_telegram_updates": 1,
11 | "telegram_token": "5587505616:AAGsjLRzjx5YQaVNwc88mAfLyPZRPXO505k",
12 | "telegram_chat_id": "-819756645",
13 | "resolution": "768",
14 | "batch_size": "16",
15 | "train_epocs": "200",
16 | "mixed_precision": "fp16",
17 | "use_8bit_adam": 1,
18 | "use_gradient_checkpointing": 1,
19 | "accumulation_steps": "1",
20 | "learning_rate": "1e-6",
21 | "warmup_steps": "0",
22 | "learning_rate_scheduler": "constant",
23 | "use_latent_cache": 1,
24 | "save_latent_cache": 1,
25 | "regenerate_latent_cache": 1,
26 | "train_text_encoder": 1,
27 | "with_prior_loss_preservation": 0,
28 | "prior_loss_preservation_weight": "1.0",
29 | "use_image_names_as_captions": 1,
30 | "auto_balance_concept_datasets": 1,
31 | "add_class_images_to_dataset": 0,
32 | "number_of_class_images": "200",
33 | "save_every_n_epochs": "15",
34 | "number_of_samples_to_generate": "1",
35 | "sample_height": "768",
36 | "sample_width": "768",
37 | "sample_random_aspect_ratio": 0,
38 | "save_on_training_start": 0,
39 | "aspect_ratio_bucketing": 0,
40 | "seed": "12354"
41 | }
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ST
2 | # This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
3 |
4 | # Copyright 2022 Sygil-Dev team.
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU Affero General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 |
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU Affero General Public License for more details.
14 |
15 | # You should have received a copy of the GNU Affero General Public License
16 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 | channels:
18 | - conda-forge
19 | - defaults
20 | - nvidia
21 | # Psst. If you change a dependency, make sure it's mirrored in the docker requirement
22 | # files as well.
23 | dependencies:
24 | - nodejs
25 | - yarn
26 | - cudatoolkit
27 | - git
28 | - numpy
29 | - pip
30 | - python=3.10.0
31 | - scikit-image
32 | - pip:
33 | - -r requirements.txt
--------------------------------------------------------------------------------
/install_stabletuner.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | :: This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
3 | ::
4 | :: Copyright 2022 Sygil-Dev team.
5 | :: This program is free software: you can redistribute it and/or modify
6 | :: it under the terms of the GNU Affero General Public License as published by
7 | :: the Free Software Foundation, either version 3 of the License, or
8 | :: (at your option) any later version.
9 | ::
10 | :: This program is distributed in the hope that it will be useful,
11 | :: but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | :: MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | :: GNU Affero General Public License for more details.
14 | ::
15 | :: You should have received a copy of the GNU Affero General Public License
16 | :: along with this program. If not, see <https://www.gnu.org/licenses/>.
17 | :: Run all commands using this script's directory as the working directory
18 | cd %~dp0
19 |
20 | :: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter
21 | for /F "tokens=2 delims=: " %%i in (environment.yaml) DO (
22 | set v_conda_env_name=%%i
23 | goto EOL
24 | )
25 | :EOL
26 |
27 | echo Environment name is set as %v_conda_env_name% as per environment.yaml
28 |
29 | :: Put the path to conda directory in a file called "custom_conda_path.txt" if it's installed at non-standard path
30 | IF EXIST custom_conda_path.txt (
31 | FOR /F %%i IN (custom_conda_path.txt) DO set v_custom_path=%%i
32 | )
33 |
34 | set INSTALL_ENV_DIR=%cd%\installer_files\env
35 | set PATH=%INSTALL_ENV_DIR%;%INSTALL_ENV_DIR%\Library\bin;%INSTALL_ENV_DIR%\Scripts;%INSTALL_ENV_DIR%\Library\usr\bin;%PATH%
36 |
37 | set v_paths=%INSTALL_ENV_DIR%
38 | set v_paths=%v_paths%;%ProgramData%\miniconda3
39 | set v_paths=%v_paths%;%USERPROFILE%\miniconda3
40 | set v_paths=%v_paths%;%ProgramData%\anaconda3
41 | set v_paths=%v_paths%;%USERPROFILE%\anaconda3
42 |
43 | :: Prepend the custom conda path (if any) so it is searched first
44 | IF NOT "%v_custom_path%"=="" (
45 |     set v_paths=%v_custom_path%;%v_paths%
46 | )
48 |
49 | for %%a in (%v_paths%) do (
50 | if EXIST "%%a\Scripts\activate.bat" (
51 | SET v_conda_path=%%a
52 | echo anaconda3/miniconda3 detected in %%a
53 | goto :CONDA_FOUND
54 | )
55 | )
56 |
57 | IF "%v_conda_path%"=="" (
58 | echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html
59 | pause
60 | exit /b 1
61 | )
62 |
63 | :CONDA_FOUND
64 | echo Found Anaconda
65 |
66 | :SKIP_RESTORE
67 | call "%v_conda_path%\Scripts\activate.bat"
68 |
69 |
70 | call conda env create --name "%v_conda_env_name%" -f environment.yaml
71 |
72 |
73 |
74 | call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%"
75 |
76 | :PROMPT
77 | python scripts/windows_install.py
78 | pause
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.15.0
2 | transformers>=4.25.1
3 | ftfy==6.1.1
4 | albumentations==1.3.0
5 | opencv-python==4.6.0.66
6 | einops==0.6.0
7 | pytorch_lightning==1.8.5.post0
8 | bitsandbytes==0.35.0
9 | tensorboard==2.11.0
10 | gitpython==3.1.29
11 | fairscale==0.4.13
12 | timm==0.6.12
13 | OmegaConf==2.3.0
14 | safetensors==0.2.6
15 | customtkinter==5.0.3
16 | tokenizers==0.13.2
17 | pyperclip==1.8.2
18 | gradio
19 | keyboard
20 | huggingface
--------------------------------------------------------------------------------
/resources/accelerate_windows/accelerate_default_config.yaml:
--------------------------------------------------------------------------------
1 | command_file: null
2 | commands: null
3 | compute_environment: LOCAL_MACHINE
4 | deepspeed_config: {}
5 | distributed_type: 'NO'
6 | downcast_bf16: 'no'
7 | dynamo_backend: 'NO'
8 | fsdp_config: {}
9 | gpu_ids: all
10 | machine_rank: 0
11 | main_process_ip: null
12 | main_process_port: null
13 | main_training_function: main
14 | megatron_lm_config: {}
15 | mixed_precision: fp16
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_name: null
21 | tpu_zone: null
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/resources/bitsandbytes_windows/cextension.py:
--------------------------------------------------------------------------------
1 | import ctypes as ct
2 | from pathlib import Path
3 | from warnings import warn
4 |
5 | from .cuda_setup.main import evaluate_cuda_setup
6 |
7 |
8 | class CUDALibrary_Singleton(object):
9 | _instance = None
10 |
11 | def __init__(self):
12 | raise RuntimeError("Call get_instance() instead")
13 |
14 | def initialize(self):
15 | binary_name = evaluate_cuda_setup()
16 | package_dir = Path(__file__).parent
17 | binary_path = package_dir / binary_name
18 |
19 | if not binary_path.exists():
20 | print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
21 | legacy_binary_name = "libbitsandbytes.so"
22 | print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
23 | binary_path = package_dir / legacy_binary_name
24 | if not binary_path.exists():
25 | print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
26 | print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
27 | raise Exception('CUDA SETUP: Setup Failed!')
28 | # self.lib = ct.cdll.LoadLibrary(binary_path)
29 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
30 | else:
31 | print(f"CUDA SETUP: Loading binary {binary_path}...")
32 | # self.lib = ct.cdll.LoadLibrary(binary_path)
33 | self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
34 |
35 | @classmethod
36 | def get_instance(cls):
37 | if cls._instance is None:
38 | cls._instance = cls.__new__(cls)
39 | cls._instance.initialize()
40 | return cls._instance
41 |
42 |
43 | lib = CUDALibrary_Singleton.get_instance().lib
44 | try:
45 | lib.cadam32bit_g32
46 | lib.get_context.restype = ct.c_void_p
47 | lib.get_cusparse.restype = ct.c_void_p
48 | COMPILED_WITH_CUDA = True
49 | except AttributeError:
50 | warn(
51 | "The installed version of bitsandbytes was compiled without GPU support. "
52 | "8-bit optimizers and GPU quantization are unavailable."
53 | )
54 | COMPILED_WITH_CUDA = False
55 |
--------------------------------------------------------------------------------
/resources/bitsandbytes_windows/libbitsandbytes_cpu.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devilismyfriend/StableTuner/4c63853399590289ea88b6fcd6565f1a407c916c/resources/bitsandbytes_windows/libbitsandbytes_cpu.dll
--------------------------------------------------------------------------------
/resources/bitsandbytes_windows/libbitsandbytes_cuda116.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devilismyfriend/StableTuner/4c63853399590289ea88b6fcd6565f1a407c916c/resources/bitsandbytes_windows/libbitsandbytes_cuda116.dll
--------------------------------------------------------------------------------
/resources/bitsandbytes_windows/main.py:
--------------------------------------------------------------------------------
1 | """
2 | extract factors the build is dependent on:
3 | [X] compute capability
4 | [ ] TODO: Q - What if we have multiple GPUs of different makes?
5 | - CUDA version
6 | - Software:
7 | - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
8 | - CuBLAS-LT: full-build 8-bit optimizer
9 | - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
10 |
11 | evaluation:
12 | - if paths faulty, return meaningful error
13 | - else:
14 | - determine CUDA version
15 | - determine capabilities
16 | - based on that set the default path
17 | """
18 |
19 | import ctypes
20 |
21 | from .paths import determine_cuda_runtime_lib_path
22 |
23 |
24 | def check_cuda_result(cuda, result_val):
25 | # 3. Check for CUDA errors
26 | if result_val != 0:
27 | error_str = ctypes.c_char_p()
28 | cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
29 | print(f"CUDA exception! Error code: {error_str.value.decode()}")
30 |
31 | def get_cuda_version(cuda, cudart_path):
32 | # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
33 | try:
34 | cudart = ctypes.CDLL(cudart_path)
35 | except OSError:
36 | # TODO: shouldn't we error or at least warn here?
37 | print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
38 | return None
39 |
40 | version = ctypes.c_int()
41 | check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
42 | version = int(version.value)
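# cudaRuntimeGetVersion encodes the version as 1000*major + 10*minor, e.g. 11060 -> CUDA 11.6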
43 | major = version//1000
44 | minor = (version-(major*1000))//10
45 |
46 | if major < 11:
47 | print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')
48 |
49 | return f'{major}{minor}'
50 |
51 |
52 | def get_cuda_lib_handle():
53 | # 1. find libcuda.so library (GPU driver) (/usr/lib)
54 | try:
55 | cuda = ctypes.CDLL("libcuda.so")
56 | except OSError:
57 | # TODO: shouldn't we error or at least warn here?
58 | print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
59 | return None
60 | check_cuda_result(cuda, cuda.cuInit(0))
61 |
62 | return cuda
63 |
64 |
65 | def get_compute_capabilities(cuda):
66 | """
67 | 1. find libcuda.so library (GPU driver) (/usr/lib)
68 | init_device -> init variables -> call function by reference
69 | 2. call extern C function to determine CC
70 | (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
71 | 3. Check for CUDA errors
72 | https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
73 | # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
74 | """
75 |
76 |
77 | nGpus = ctypes.c_int()
78 | cc_major = ctypes.c_int()
79 | cc_minor = ctypes.c_int()
80 |
81 | device = ctypes.c_int()
82 |
83 | check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
84 | ccs = []
85 | for i in range(nGpus.value):
86 | check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
87 | ref_major = ctypes.byref(cc_major)
88 | ref_minor = ctypes.byref(cc_minor)
89 | # 2. call extern C function to determine CC
90 | check_cuda_result(
91 | cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
92 | )
93 | ccs.append(f"{cc_major.value}.{cc_minor.value}")
94 |
95 | return ccs
96 |
97 |
98 | # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
99 | def get_compute_capability(cuda):
100 | """
101 | Extracts the highest compute capability from all available GPUs, as compute
102 | capabilities are downwards compatible. If no GPUs are detected, it returns
103 | None.
104 | """
105 | ccs = get_compute_capabilities(cuda)
106 | if ccs is not None:
107 | # TODO: handle different compute capabilities; for now, take the max
108 | return ccs[-1]
109 | return None
110 |
111 |
112 | def evaluate_cuda_setup():
113 | print('')
114 | print('='*35 + 'BUG REPORT' + '='*35)
115 | print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
116 | print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
117 | print('='*80)
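# $$$ Windows patch: always load the bundled CUDA 11.6 DLL; the detection logic below is unreachable.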
118 | return "libbitsandbytes_cuda116.dll" # $$$
119 |
120 | binary_name = "libbitsandbytes_cpu.so"
121 | #if not torch.cuda.is_available():
122 | #print('No GPU detected. Loading CPU library...')
123 | #return binary_name
124 |
125 | cudart_path = determine_cuda_runtime_lib_path()
126 | if cudart_path is None:
127 | print(
128 | "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
129 | )
130 | return binary_name
131 |
132 | print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
133 | cuda = get_cuda_lib_handle()
134 | cc = get_compute_capability(cuda)
135 | print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
136 | cuda_version_string = get_cuda_version(cuda, cudart_path)
137 |
138 |
139 | if cc == '':
140 | print(
141 | "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
142 | )
143 | return binary_name
144 |
145 | # 7.5 is the minimum CC for cublaslt
146 | has_cublaslt = cc in ["7.5", "8.0", "8.6"]
147 |
148 | # TODO:
149 | # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
150 | # (2) Multiple CUDA versions installed
151 |
152 | # we use ls -l instead of nvcc to determine the cuda version
153 | # since most installations will have the libcudart.so installed, but not the compiler
154 | print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
155 |
156 | def get_binary_name():
157 | "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
158 | bin_base_name = "libbitsandbytes_cuda"
159 | if has_cublaslt:
160 | return f"{bin_base_name}{cuda_version_string}.so"
161 | else:
162 | return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
163 |
164 | binary_name = get_binary_name()
165 |
166 | return binary_name
167 |
--------------------------------------------------------------------------------
/resources/stableTuner_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devilismyfriend/StableTuner/4c63853399590289ea88b6fcd6565f1a407c916c/resources/stableTuner_icon.png
--------------------------------------------------------------------------------
/resources/stableTuner_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devilismyfriend/StableTuner/4c63853399590289ea88b6fcd6565f1a407c916c/resources/stableTuner_logo.png
--------------------------------------------------------------------------------
/resources/stableTuner_notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "\n",
9 | "# Welcome to StableTuner, Let's get started!\n",
10 | "#### This notebook will guide you through the setup process.\n",
11 | "\n",
12 | "\n",
13 | "__[Join the ST Discord for support, chat and fun times :)](https://discord.gg/DahNECrBUZ)__"
14 | ]
15 | },
16 | {
17 | "attachments": {},
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "#### Start by uploading your payload.zip file (just drag and drop it to the file area) and run this cell as it gets uploaded."
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "#Much thanks to IndustriaDitat and entmike for helping making ST linux compatible!\n",
31 | "from IPython.display import clear_output\n",
32 | "from subprocess import getoutput\n",
33 | "installed_xformers = False\n",
34 | "GPU_CardName = getoutput('nvidia-smi --query-gpu=name --format=csv,noheader')\n",
35 | "\n",
36 | "%pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url \"https://download.pytorch.org/whl/cu116\"\n",
37 | "%pip install -U --pre triton\n",
38 | "%pip install ninja bitsandbytes\n",
39 | "if '4090' in GPU_CardName:\n",
40 | " %pip install https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/RTX4090-xf14-cu116-py38/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl\n",
41 | " installed_xformers = True\n",
42 | "if '3090' in GPU_CardName:\n",
43 | " %pip install https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/RTX3090-xf14-cu116-py38/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl\n",
44 | " installed_xformers = True\n",
45 | "if 'A5000' in GPU_CardName:\n",
46 | " %pip install https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/A5000-xf14-cu116-py38/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64.whl\n",
47 | " installed_xformers = True\n",
48 | "if 'T4' in GPU_CardName:\n",
49 | " %pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/xformers-0.0.14.dev0-cp38-cp38-linux_x86_64_t4.whl\n",
50 | " installed_xformers = True\n",
51 | "if 'A100' in GPU_CardName:\n",
52 | " %pip install https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/A100_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
53 | " installed_xformers = True\n",
54 | "if 'V100' in GPU_CardName:\n",
55 | " %pip install https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/V100_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
56 | " installed_xformers = True\n",
57 | "if installed_xformers == False:\n",
58 | " clear_output()\n",
59 | " print(\"No precompiled xformers found for your GPU. Please wait while we compile xformers for your GPU, this might take 20-40 minutes.\")\n",
60 | " %pip install git+https://github.com/facebookresearch/xformers@1d31a3a#egg=xformers\n",
61 | "%pip install git+https://github.com/huggingface/diffusers.git@0ca1724#egg=diffusers --force-reinstall\n",
62 | "clear_output()\n",
63 | "print(\"Done!\")"
64 | ]
65 | },
66 | {
67 | "attachments": {},
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "#### Upload finished?, time to run this next cell!"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "!unzip -o -q payload.zip\n",
81 | "%pip install -r requirements.txt\n",
82 | "clear_output()\n",
83 | "print(\"Done!\")"
84 | ]
85 | },
86 | {
87 | "attachments": {},
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "#### Looks like you're done installing, let's get training!\n",
92 | "\n"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "changeMe"
102 | ]
103 | },
104 | {
105 | "attachments": {},
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "## Model Playground\n",
110 | "\n",
111 | "#### This is where you can test your model and package it up."
112 | ]
113 | },
114 | {
115 | "attachments": {},
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "Run this cell and select your output model, you can upload the model to HuggingFace or run the next cell to use the Web UI and play around with your model."
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "from ipywidgets import widgets\n",
129 | "import os\n",
130 | "import glob\n",
131 | "from IPython.display import clear_output\n",
132 | "import torch\n",
133 | "from torch import autocast\n",
134 | "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
135 | "from IPython.display import display\n",
136 | "import random\n",
137 | "import gradio as gr\n",
138 | "from requests import HTTPError\n",
139 | "from huggingface_hub import create_repo\n",
140 | "from huggingface_hub import HfApi\n",
141 | "from huggingface_hub import login\n",
142 | "from huggingface_hub import logging\n",
143 | "from huggingface_hub.utils import hf_raise_for_status, HfHubHTTPError\n",
144 | "def upload_hf(token,repo_name,model_path):\n",
145 | " if token == '':\n",
146 | " print(\"Please enter your HuggingFace token.\")\n",
147 | " return\n",
148 | " if repo_name == '':\n",
149 | " print(\"Please enter your model name.\")\n",
150 | " return\n",
151 | " if model_path == '':\n",
152 | " print(\"Couldn't find a valid model\")\n",
153 | " return\n",
154 | " api = login(token=token)\n",
155 | " api = HfApi()\n",
156 | " hf_username = HfApi().whoami(token)['name']\n",
157 | " project_repo = repo_name\n",
158 | " try:\n",
159 | " create_repo(f'{hf_username}/{project_repo}', repo_type='model',token=token)\n",
160 | " except HTTPError as http_e:\n",
161 | " if http_e.response.status_code == 409:\n",
162 | " print('The repo already exists')\n",
163 | " pass\n",
164 | " else:\n",
165 | " print(f'An error occurred: {http_e.reason}')\n",
166 | " except HfHubHTTPError as hfhub_e:\n",
167 | " if hfhub_e.response.status_code == 409:\n",
168 | " pass\n",
169 | " else:\n",
170 | " print(f'An error occurred: {hfhub_e.message}')\n",
171 | " try:\n",
172 | " print('Uploading...')\n",
173 | " api.upload_folder(\n",
174 | " folder_path=model_path,\n",
175 | " path_in_repo='',\n",
176 | " repo_id=f'{hf_username}/{project_repo}',\n",
177 | " repo_type=\"model\",\n",
178 | " ignore_patterns=\"**/logs/\",\n",
179 | " )\n",
180 | " print('Done!')\n",
181 | " print(f'Model is at https://huggingface.co/{hf_username}/{project_repo}')\n",
182 | " except Exception as general_e:\n",
183 | " print(f'Exception occurred: {general_e}')\n",
184 | "if 'output' not in os.listdir():\n",
185 | " print(\"No output folder found. Please run the training cell first.\")\n",
186 | "models = []\n",
187 | "model_dir = os.listdir('output')[0]\n",
188 | "output_sort = sorted(glob.iglob('output' + os.sep + model_dir + os.sep+ '*'), key=os.path.getctime, reverse=True)\n",
189 | "if len(output_sort) == 0:\n",
190 | " print(\"No models found in output folder. Please run the training cell first.\")\n",
191 | "for model in output_sort:\n",
192 | " required_folders = [\"vae\", \"unet\", \"tokenizer\", \"text_encoder\"]\n",
193 | " if all(x in os.listdir(model) for x in required_folders):\n",
194 | " models.append(model)\n",
195 | "model_selection = widgets.Dropdown(\n",
196 | " layout={'width': 'initial'},\n",
197 | " style={'description_width': 'initial'},\n",
198 | " options=models,\n",
199 | " value=models[0],\n",
200 | " # rows=10,\n",
201 | " description='Select Checkpoint:',\n",
202 | " disabled=False\n",
203 | ")\n",
204 | "upload_btn = widgets.Button(\n",
205 | " description='Upload to HuggingFace Hub',\n",
206 | " style={'description_width': 'initial'},\n",
207 | " layout={'width': 'initial'},\n",
208 | " disabled=False,\n",
209 | " button_style='', # 'success', 'info', 'warning', 'danger' or ''\n",
210 | " tooltip='Press to start upload',\n",
211 | " icon='check' # (FontAwesome names without the `fa-` prefix)\n",
212 | ")\n",
213 | "token_txt = widgets.Text(\n",
214 | " value='',\n",
215 | " style={'description_width': 'initial'},\n",
216 | " layout={'width': 'initial'},\n",
217 | " placeholder='HF Token',\n",
218 | " description='Hugging Face Token:',\n",
219 | " disabled=False\n",
220 | ")\n",
221 | "repo_txt = widgets.Text(\n",
222 | " value=model_dir,\n",
223 | " style={'description_width': 'initial'},\n",
224 | " placeholder='Give your model a name',\n",
225 | " description='Model Name:',\n",
226 | " disabled=False\n",
227 | ")\n",
228 | "upload_btn.on_click(lambda x: upload_hf(token_txt.value,repo_txt.value,model_selection.value))\n",
229 | "clear_output()\n",
230 | "display(model_selection)\n",
231 | "display(token_txt)\n",
232 | "display(repo_txt)\n",
233 | "display(upload_btn)\n",
234 | "print('You can input your HuggingFace token and repo to upload your model to the HuggingFace Hub, make sure to use a write API token!')\n",
235 | "print('Alternatively, you can run the next cell to open up a UI where you can generate and zip up the model.')\n"
236 | ]
237 | },
238 | {
239 | "attachments": {},
240 | "cell_type": "markdown",
241 | "metadata": {},
242 | "source": [
243 | "All ready?, lets play around with it ;), this next cell will load a small UI for you to generate images and zip it up for download, re-run this cell if you selected a new model!"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {},
250 | "outputs": [],
251 | "source": [
252 | "model_path = model_selection.value\n",
253 | "pipe = StableDiffusionPipeline.from_pretrained(model_path,safety_checker=None, torch_dtype=torch.float16).to(\"cuda\")\n",
254 | "scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
255 | "pipe.scheduler = scheduler\n",
256 | "print('Loaded checkpoint')\n",
257 | "def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50,seed=-1,guidance_scale=7.5):\n",
258 | " with torch.autocast(\"cuda\"), torch.inference_mode():\n",
259 | " if seed != -1:\n",
260 | " g_cuda = torch.Generator(device='cuda')\n",
261 | " g_cuda.manual_seed(int(seed))\n",
262 | " else:\n",
263 | " seed = random.randint(0, 100000)\n",
264 | " g_cuda = torch.Generator(device='cuda')\n",
265 | " g_cuda.manual_seed(seed)\n",
266 | " return pipe(\n",
267 | " prompt, height=int(height), width=int(width),\n",
268 | " negative_prompt=negative_prompt,\n",
269 | " num_images_per_prompt=int(num_samples),\n",
270 | " num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,\n",
271 | " generator=g_cuda\n",
272 | " ).images, seed\n",
273 | "def zip_model():\n",
274 | " import shutil\n",
275 | " print('Zipping Model!, Please wait until you see a done message, this can take a few minutes, you can keep generating while you wait!')\n",
276 | " curLocation = os.getcwd()\n",
277 | " model_name = os.path.dirname(model_path)\n",
278 | " shutil.make_archive(model_name,'zip',model_path)\n",
279 | " os.chdir(curLocation)\n",
280 | " print('Done!')\n",
281 | "with gr.Blocks() as demo:\n",
282 | " with gr.Row():\n",
283 | " with gr.Column():\n",
284 | " prompt = gr.Textbox(label=\"Prompt\", value=\"photo of zwx dog in a bucket\")\n",
285 | " negative_prompt = gr.Textbox(label=\"Negative Prompt\", value=\"\")\n",
286 | " with gr.Row():\n",
287 | " run = gr.Button(value=\"Generate\")\n",
288 | " zip = gr.Button(value=\"Zip Model For Download\")\n",
289 | " with gr.Row():\n",
290 | " num_samples = gr.Number(label=\"Number of Samples\", value=4)\n",
291 | " guidance_scale = gr.Number(label=\"Guidance Scale\", value=7.5)\n",
292 | " with gr.Row():\n",
293 | " height = gr.Number(label=\"Height\", value=512)\n",
294 | " width = gr.Number(label=\"Width\", value=512)\n",
295 | " with gr.Row():\n",
296 | " num_inference_steps = gr.Slider(label=\"Steps\", value=25)\n",
297 | " seed = gr.Number(label=\"Seed\", value=-1)\n",
298 | " with gr.Column():\n",
299 | " gallery = gr.Gallery()\n",
300 | " seedDisplay = gr.Number(label=\"Used Seed:\", value=0)\n",
301 | "\n",
302 | " run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps,seed, guidance_scale], outputs=[gallery,seedDisplay])\n",
303 | " zip.click(zip_model)\n",
304 | "demo.launch(debug=True,share=True)"
305 | ]
306 | },
307 | {
308 | "attachments": {},
309 | "cell_type": "markdown",
310 | "metadata": {},
311 | "source": [
312 | "### This is the end, for now :) ,you can convert your model to CKPT back in StableTuner!."
313 | ]
314 | }
315 | ],
316 | "metadata": {
317 | "kernelspec": {
318 | "display_name": "Python 3",
319 | "language": "python",
320 | "name": "python3"
321 | },
322 | "language_info": {
323 | "codemirror_mode": {
324 | "name": "ipython",
325 | "version": 3
326 | },
327 | "file_extension": ".py",
328 | "mimetype": "text/x-python",
329 | "name": "python",
330 | "nbconvert_exporter": "python",
331 | "pygments_lexer": "ipython3",
332 | "version": "3.10.9 (tags/v3.10.9:1dd9be6, Dec 6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)]"
333 | },
334 | "orig_nbformat": 4,
335 | "vscode": {
336 | "interpreter": {
337 | "hash": "886cb931ea414ad2a87adcccbb1ce9166879eb6056301acd331591c6290ceca8"
338 | }
339 | }
340 | },
341 | "nbformat": 4,
342 | "nbformat_minor": 2
343 | }
344 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devilismyfriend/StableTuner/4c63853399590289ea88b6fcd6565f1a407c916c/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/captionBuddy.py:
--------------------------------------------------------------------------------
1 | import tkinter as tk
2 | from tkinter import ttk, Menu
3 | import os
4 | import subprocess
5 | from PIL import Image, ImageTk, ImageDraw
6 | import tkinter.filedialog as fd
7 | import json
8 | import sys
9 | from torchvision import transforms
10 | from torchvision.transforms.functional import InterpolationMode
11 | import torch
15 | import numpy as np
16 | import requests
17 | import random
18 | import customtkinter as ctk
19 | from customtkinter import ThemeManager
20 |
21 | from clip_segmentation import ClipSeg
22 |
23 | #main class
24 | ctk.set_appearance_mode("dark")
25 | ctk.set_default_color_theme("blue")
26 |
27 | class BatchMaskWindow(ctk.CTkToplevel):
28 | def __init__(self, parent, path, *args, **kwargs):
29 | ctk.CTkToplevel.__init__(self, parent, *args, **kwargs)
30 | self.parent = parent
31 |
32 | self.title("Batch process masks")
33 | self.geometry("320x310")
34 | self.resizable(False, False)
35 | self.wait_visibility()
36 | self.grab_set()
37 | self.focus_set()
38 |
39 | self.mode_var = tk.StringVar(self, "Create if absent")
40 | self.modes = ["Replace all masks", "Create if absent", "Add to existing", "Subtract from existing"]
41 |
42 | self.frame = ctk.CTkFrame(self, width=600, height=300)
43 | self.frame.grid(row=0, column=0, sticky="nsew", padx=10, pady=10)
44 |
45 | self.path_label = ctk.CTkLabel(self.frame, text="Folder", width=100)
46 | self.path_label.grid(row=0, column=0, sticky="w",padx=5, pady=5)
47 | self.path_entry = ctk.CTkEntry(self.frame, width=150)
48 | self.path_entry.insert(0, path)
49 | self.path_entry.grid(row=0, column=1, sticky="w", padx=5, pady=5)
50 | self.path_button = ctk.CTkButton(self.frame, width=30, text="...", command=lambda: self.browse_for_path(self.path_entry))
51 | self.path_button.grid(row=0, column=1, sticky="e", padx=5, pady=5)
52 |
53 | self.prompt_label = ctk.CTkLabel(self.frame, text="Prompt", width=100)
54 | self.prompt_label.grid(row=1, column=0, sticky="w",padx=5, pady=5)
55 | self.prompt_entry = ctk.CTkEntry(self.frame, width=200)
56 | self.prompt_entry.grid(row=1, column=1, sticky="w", padx=5, pady=5)
57 |
58 | self.mode_label = ctk.CTkLabel(self.frame, text="Mode", width=100)
59 | self.mode_label.grid(row=2, column=0, sticky="w", padx=5, pady=5)
60 | self.mode_dropdown = ctk.CTkOptionMenu(self.frame, variable=self.mode_var, values=self.modes, dynamic_resizing=False, width=200)
61 | self.mode_dropdown.grid(row=2, column=1, sticky="w", padx=5, pady=5)
62 |
63 | self.threshold_label = ctk.CTkLabel(self.frame, text="Threshold", width=100)
64 | self.threshold_label.grid(row=3, column=0, sticky="w", padx=5, pady=5)
65 | self.threshold_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="0.0 - 1.0")
66 | self.threshold_entry.insert(0, "0.3")
67 | self.threshold_entry.grid(row=3, column=1, sticky="w", padx=5, pady=5)
68 |
69 | self.smooth_label = ctk.CTkLabel(self.frame, text="Smooth", width=100)
70 | self.smooth_label.grid(row=4, column=0, sticky="w", padx=5, pady=5)
71 | self.smooth_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="5")
72 | self.smooth_entry.insert(0, 5)
73 | self.smooth_entry.grid(row=4, column=1, sticky="w", padx=5, pady=5)
74 |
75 | self.expand_label = ctk.CTkLabel(self.frame, text="Expand", width=100)
76 | self.expand_label.grid(row=5, column=0, sticky="w", padx=5, pady=5)
77 | self.expand_entry = ctk.CTkEntry(self.frame, width=200, placeholder_text="10")
78 | self.expand_entry.insert(0, 10)
79 | self.expand_entry.grid(row=5, column=1, sticky="w", padx=5, pady=5)
80 |
81 | self.progress_label = ctk.CTkLabel(self.frame, text="Progress: 0/0", width=100)
82 | self.progress_label.grid(row=6, column=0, sticky="w", padx=5, pady=5)
83 | self.progress = ctk.CTkProgressBar(self.frame, orientation="horizontal", mode="determinate", width=200)
84 | self.progress.grid(row=6, column=1, sticky="w", padx=5, pady=5)
85 |
86 | self.create_masks_button = ctk.CTkButton(self.frame, text="Create Masks", width=310, command=self.create_masks)
87 | self.create_masks_button.grid(row=7, column=0, columnspan=2, sticky="w", padx=5, pady=5)
88 |
89 | self.frame.pack(fill="both", expand=True)
90 |
91 | def browse_for_path(self, entry_box):
92 | # get the path from the user
93 | path = fd.askdirectory()
94 | # set the path to the entry box
95 | # delete entry box text
96 | entry_box.focus_set()
97 | entry_box.delete(0, tk.END)
98 | entry_box.insert(0, path)
99 | self.focus_set()
100 |
101 | def set_progress(self, value, max_value):
102 | progress = value / max_value
103 | self.progress.set(progress)
104 | self.progress_label.configure(text="{0}/{1}".format(value, max_value))
105 | self.progress.update()
106 |
107 | def create_masks(self):
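# load the CLIPSeg segmentation model on demand; it is only needed for mask generation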
108 | self.parent.load_clip_seg_model()
109 |
110 | mode = {
111 | "Replace all masks": "replace",
112 | "Create if absent": "fill",
113 | "Add to existing": "add",
114 | "Subtract from existing": "subtract"
115 | }[self.mode_var.get()]
116 |
117 | self.parent.clip_seg.mask_folder(
118 | sample_dir=self.path_entry.get(),
119 | prompts=[self.prompt_entry.get()],
120 | mode=mode,
121 | threshold=float(self.threshold_entry.get()),
122 | smooth_pixels=int(self.smooth_entry.get()),
123 | expand_pixels=int(self.expand_entry.get()),
124 | progress_callback=self.set_progress,
125 | )
126 | self.parent.load_image()
127 |
128 |
129 | def _check_file_type(f: str) -> bool:
130 | return f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', ".bmp", ".tiff"))
131 |
132 |
133 | class ImageBrowser(ctk.CTkToplevel):
134 | def __init__(self,mainProcess=None):
135 | super().__init__()
136 | if not os.path.exists("scripts/BLIP"):
137 | print("Getting BLIP from GitHub.")
138 | subprocess.run(["git", "clone", "https://github.com/salesforce/BLIP", "scripts/BLIP"])
139 | #if not os.path.exists("scripts/CLIP"):
140 | # print("Getting CLIP from GitHub.")
141 | # subprocess.run(["git", "clone", "https://github.com/pharmapsychotic/clip-interrogator.git", "scripts/CLIP"])
142 | blip_path = "scripts/BLIP"
143 | sys.path.append(blip_path)
144 | #clip_path = "scripts/CLIP"
145 | #sys.path.append(clip_path)
146 | self.mainProcess = mainProcess
147 | self.captioner_folder = os.path.dirname(os.path.realpath(__file__))
148 | self.clip_seg = None
149 | self.PILimage = None
150 | self.PILmask = None
151 | self.mask_draw_x = 0
152 | self.mask_draw_y = 0
153 | self.mask_draw_radius = 20
154 | #self = master
155 | #self.overrideredirect(True)
156 | #self.title_bar = TitleBar(self)
157 | #self.title_bar.pack(side="top", fill="x")
158 | #make not user resizable
159 | self.title("Caption Buddy")
160 | #self.resizable(False, False)
161 | self.geometry("720x820")
162 | self.top_frame = ctk.CTkFrame(self,fg_color='transparent')
163 | self.top_frame.pack(side="top", fill="x",expand=False)
164 | self.top_subframe = ctk.CTkFrame(self.top_frame,fg_color='transparent')
165 | self.top_subframe.pack(side="bottom", fill="x",pady=10)
166 | self.top_subframe.grid_columnconfigure(0, weight=1)
167 | self.top_subframe.grid_columnconfigure(1, weight=1)
168 | self.tip_frame = ctk.CTkFrame(self,fg_color='transparent')
169 | self.tip_frame.pack(side="top")
170 | self.dark_mode_var = "#202020"
171 | #self.dark_purple_mode_var = "#1B0F1B"
172 | self.dark_mode_title_var = "#286aff"
173 | self.dark_mode_button_pressed_var = "#BB91B6"
174 | self.dark_mode_button_var = "#8ea0e1"
175 | self.dark_mode_text_var = "#c6c7c8"
176 | #self.configure(bg_color=self.dark_mode_var)
177 | self.canvas = ctk.CTkLabel(self,text='', width=600, height=600)
178 | #self.canvas.configure(bg_color=self.dark_mode_var)
179 | #create temporary image for canvas
180 | self.canvas.pack()
181 | self.cur_img_index = 0
182 | self.image_count = 0
183 | #make a frame with a grid under the canvas
184 | self.frame = ctk.CTkFrame(self)
185 | #grid
186 | self.frame.grid_columnconfigure(0, weight=1)
187 | self.frame.grid_columnconfigure(1, weight=100)
188 | self.frame.grid_columnconfigure(2, weight=1)
189 | self.frame.grid_rowconfigure(0, weight=1)
190 |
191 | #show the frame
192 | self.frame.pack(side="bottom", fill="x")
193 | #bottom frame
194 | self.bottom_frame = ctk.CTkFrame(self)
195 | #make grid
196 | self.bottom_frame.grid_columnconfigure(0, weight=0)
197 | self.bottom_frame.grid_columnconfigure(1, weight=2)
198 | self.bottom_frame.grid_columnconfigure(2, weight=0)
199 | self.bottom_frame.grid_columnconfigure(3, weight=2)
200 | self.bottom_frame.grid_columnconfigure(4, weight=0)
201 | self.bottom_frame.grid_columnconfigure(5, weight=2)
202 | self.bottom_frame.grid_rowconfigure(0, weight=1)
203 | #show the frame
204 | self.bottom_frame.pack(side="bottom", fill="x")
205 |
206 | self.image_index = 0
207 | self.image_list = []
208 | self.caption = ''
209 | self.caption_file = ''
210 | self.caption_file_path = ''
211 | self.caption_file_name = ''
212 | self.caption_file_ext = ''
213 | self.caption_file_name_no_ext = ''
214 | self.output_format='text'
215 | #check if bad_files.txt exists
216 | if os.path.exists("bad_files.txt"):
217 | #delete it
218 | os.remove("bad_files.txt")
219 | self.use_blip = True
220 | self.debug = False
221 | self.create_widgets()
222 | self.load_blip_model()
223 | self.load_options()
224 | #self.open_folder()
225 |
226 | self.canvas.focus_force()
227 | self.canvas.bind("<Alt-Right>", self.next_image)
228 | self.canvas.bind("<Alt-Left>", self.prev_image)
229 | #on close window
230 | self.protocol("WM_DELETE_WINDOW", self.on_closing)
231 | def on_closing(self):
232 | #self.save_options()
233 | self.mainProcess.deiconify()
234 | self.destroy()
235 | def create_widgets(self):
236 | self.output_folder = ''
237 |
238 | # add a checkbox to toggle auto generate caption
239 | self.auto_generate_caption = tk.BooleanVar(self.top_subframe)
240 | self.auto_generate_caption.set(True)
241 | self.auto_generate_caption_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Auto Generate Caption", variable=self.auto_generate_caption,width=50)
242 | self.auto_generate_caption_checkbox.pack(side="left", fill="x", expand=True, padx=10)
243 |
244 | # add a checkbox to skip auto generating captions if they already exist
245 | self.auto_generate_caption_text_override = tk.BooleanVar(self.top_subframe)
246 | self.auto_generate_caption_text_override.set(False)
247 | self.auto_generate_caption_checkbox_text_override = ctk.CTkCheckBox(self.top_subframe, text="Skip Auto Generate If Text Caption Exists", variable=self.auto_generate_caption_text_override,width=50)
248 | self.auto_generate_caption_checkbox_text_override.pack(side="left", fill="x", expand=True, padx=10)
249 |
250 | # add a checkbox to enable mask editing
251 | self.enable_mask_editing = tk.BooleanVar(self.top_subframe)
252 | self.enable_mask_editing.set(False)
253 | self.enable_mask_editing_checkbox = ctk.CTkCheckBox(self.top_subframe, text="Enable Mask Editing", variable=self.enable_mask_editing, width=50)
254 | self.enable_mask_editing_checkbox.pack(side="left", fill="x", expand=True, padx=10)
255 |
256 | self.open_button = ctk.CTkButton(self.top_frame,text="Load Folder",fg_color=("gray75", "gray25"), command=self.open_folder,width=50)
257 | #self.open_button.grid(row=0, column=1)
258 | self.open_button.pack(side="left", fill="x",expand=True,padx=10)
259 | #add a batch folder button
260 | self.batch_folder_caption_button = ctk.CTkButton(self.top_frame, text="Batch Folder Caption", fg_color=("gray75", "gray25"), command=self.batch_folder_caption, width=50)
261 | self.batch_folder_caption_button.pack(side="left", fill="x", expand=True, padx=10)
262 | self.batch_folder_mask_button = ctk.CTkButton(self.top_frame, text="Batch Folder Mask", fg_color=("gray75", "gray25"), command=self.batch_folder_mask, width=50)
263 | self.batch_folder_mask_button.pack(side="left", fill="x", expand=True, padx=10)
264 |
265 | #add an options button to the same row as the open button
266 | self.options_button = ctk.CTkButton(self.top_frame, text="Options",fg_color=("gray75", "gray25"), command=self.open_options,width=50)
267 | self.options_button.pack(side="left", fill="x",expand=True,padx=10)
268 | #add generate caption button
269 | self.generate_caption_button = ctk.CTkButton(self.top_frame, text="Generate Caption",fg_color=("gray75", "gray25"), command=self.generate_caption,width=50)
270 | self.generate_caption_button.pack(side="left", fill="x",expand=True,padx=10)
271 |
272 | #add a label for tips under the buttons
273 | self.tips_label = ctk.CTkLabel(self.tip_frame, text="Use Alt with left and right arrow keys to navigate images, enter to save the caption.")
274 | self.tips_label.pack(side="top")
275 | #add image count label
276 | self.image_count_label = ctk.CTkLabel(self.tip_frame, text=f"Image {self.cur_img_index} of {self.image_count}")
277 | self.image_count_label.pack(side="top")
278 |
279 | self.image_label = ctk.CTkLabel(self.canvas,text='',width=100,height=100)
280 | self.image_label.grid(row=0, column=0, sticky="nsew")
281 | #self.image_label.bind("<Button-1>", self.click_canvas)
282 | self.image_label.bind("<Motion>", self.draw_mask)
283 | self.image_label.bind("<Button-1>", self.draw_mask)
284 | self.image_label.bind("<Button-3>", self.draw_mask)
285 | self.image_label.bind("<MouseWheel>", self.draw_mask_radius)
286 | #self.image_label.pack(side="top")
287 | #previous button
288 | self.prev_button = ctk.CTkButton(self.frame,text="Previous", command= lambda event=None: self.prev_image(event),width=50)
289 | #grid
290 | self.prev_button.grid(row=1, column=0, sticky="w",padx=5,pady=10)
291 | #self.prev_button.pack(side="left")
292 | #self.prev_button.bind("", self.prev_image)
293 | self.caption_entry = ctk.CTkEntry(self.frame)
294 | #grid
295 | self.caption_entry.grid(row=1, column=1, rowspan=3, sticky="nsew",pady=10)
296 | #bind to enter key
297 | self.caption_entry.bind("<Return>", self.save)
298 | self.canvas.bind("<Return>", self.save)
299 | self.caption_entry.bind("<Alt-Right>", self.next_image)
300 | self.caption_entry.bind("<Alt-Left>", self.prev_image)
301 | self.caption_entry.bind("<Control-BackSpace>", self.delete_word)
302 | #next button
303 |
304 | self.next_button = ctk.CTkButton(self.frame,text='Next', command= lambda event=None: self.next_image(event),width=50)
305 | #self.next_button["text"] = "Next"
306 | #grid
307 | self.next_button.grid(row=1, column=2, sticky="e",padx=5,pady=10)
308 | #add two entry boxes and labels in the style of :replace _ with _
309 | #create replace string variable
310 | self.replace_label = ctk.CTkLabel(self.bottom_frame, text="Replace:")
311 | self.replace_label.grid(row=0, column=0, sticky="w",padx=5)
312 | self.replace_entry = ctk.CTkEntry(self.bottom_frame)
313 | self.replace_entry.grid(row=0, column=1, sticky="nsew",padx=5)
314 | self.replace_entry.bind("<Return>", self.save)
315 | #self.replace_entry.bind("", self.replace)
316 | #with label
317 | #create with string variable
318 | self.with_label = ctk.CTkLabel(self.bottom_frame, text="With:")
319 | self.with_label.grid(row=0, column=2, sticky="w",padx=5)
320 | self.with_entry = ctk.CTkEntry(self.bottom_frame)
321 | self.with_entry.grid(row=0, column=3, sticky="nswe",padx=5)
322 | self.with_entry.bind("<Return>", self.save)
323 | #add another entry with label, add suffix
324 |
325 | #create prefix string var
326 | self.prefix_label = ctk.CTkLabel(self.bottom_frame, text="Add to start:")
327 | self.prefix_label.grid(row=0, column=4, sticky="w",padx=5)
328 | self.prefix_entry = ctk.CTkEntry(self.bottom_frame)
329 | self.prefix_entry.grid(row=0, column=5, sticky="nsew",padx=5)
330 | self.prefix_entry.bind("<Return>", self.save)
331 |
332 | #create suffix string var
333 | self.suffix_label = ctk.CTkLabel(self.bottom_frame, text="Add to end:")
334 | self.suffix_label.grid(row=0, column=6, sticky="w",padx=5)
335 | self.suffix_entry = ctk.CTkEntry(self.bottom_frame)
336 | self.suffix_entry.grid(row=0, column=7, sticky="nsew",padx=5)
337 | self.suffix_entry.bind("<Return>", self.save)
338 | self.all_entries = [self.replace_entry, self.with_entry, self.suffix_entry, self.caption_entry, self.prefix_entry]
339 | #bind right click menu to all entries
340 | for entry in self.all_entries:
341 | entry.bind("<Button-3>", self.create_right_click_menu)
342 | def batch_folder_caption(self):
343 | #show imgs in folder askdirectory
344 | #ask user if to batch current folder or select folder
345 | #if bad_files.txt exists, delete it
346 | self.bad_files = []
347 | if os.path.exists('bad_files.txt'):
348 | os.remove('bad_files.txt')
349 | try:
350 | #check if self.folder is set
351 | self.folder
352 | except AttributeError:
353 | self.folder = ''
354 | if self.folder == '':
355 | self.folder = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
356 | batch_input_dir = self.folder
357 | else:
358 | ask = tk.messagebox.askquestion("Batch Folder", "Batch current folder?")
359 | if ask == 'yes':
360 | batch_input_dir = self.folder
361 | else:
362 | batch_input_dir = fd.askdirectory(title="Select Folder to Batch Process", initialdir=os.getcwd())
363 | ask2 = tk.messagebox.askquestion("Batch Folder", "Save output to same directory?")
364 | if ask2 == 'yes':
365 | batch_output_dir = batch_input_dir
366 | else:
367 | batch_output_dir = fd.askdirectory(title="Select Folder to Save Batch Processed Images", initialdir=os.getcwd())
368 | if batch_input_dir == '':
369 | return
370 | if batch_output_dir == '':
371 | batch_output_dir = batch_input_dir
372 |
373 | self.caption_file_name = os.path.basename(batch_input_dir)
374 | self.image_list = []
375 | for file in os.listdir(batch_input_dir):
376 | if _check_file_type(file) and not file.endswith('-masklabel.png'):
377 | self.image_list.append(os.path.join(batch_input_dir, file))
378 | self.image_index = 0
379 | #use progress bar class
380 | #pba = tk.Tk()
381 | #pba.title("Batch Processing")
382 | #remove icon
383 | #pba.wm_attributes('-toolwindow','True')
384 | pb = ProgressbarWithCancel(max=len(self.image_list))
385 | #pb.set_max(len(self.image_list))
386 | pb.set_progress(0)
387 |
388 | #if batch_output_dir doesn't exist, create it
389 | if not os.path.exists(batch_output_dir):
390 | os.makedirs(batch_output_dir)
391 | for i in range(len(self.image_list)):
392 | random_chance = random.randint(0,25)
393 | if random_chance == 0:
394 | pb.set_random_label()
395 | if pb.is_cancelled():
396 | pb.destroy()
397 | return
398 | self.image_index = i
399 | #get float value of progress between 0 and 1 according to the image index and the total number of images
400 | progress = i / len(self.image_list)
401 | pb.set_progress(progress)
402 | self.update()
403 | try:
404 | img = Image.open(self.image_list[i]).convert("RGB")
405 | except Exception:
406 | self.bad_files.append(self.image_list[i])
407 | #skip file
408 | continue
409 | tensor = transforms.Compose([
410 | transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
411 | transforms.ToTensor(),
412 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
413 | ])
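# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) are the standard ImageNet
# normalization mean/std applied before the image is handed to BLIP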
414 | torch_image = tensor(img).unsqueeze(0).to(torch.device("cuda"))
415 | if self.nucleus_sampling:
416 | captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
417 | else:
418 | captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
419 | max_length=48, repetition_penalty=self.q_factor)
420 | caption = captions[0]
421 | self.replace = self.replace_entry.get()
422 | self.replace_with = self.with_entry.get()
423 | self.suffix_var = self.suffix_entry.get()
424 | self.prefix = self.prefix_entry.get()
425 | #prepare the caption
426 | if not (self.suffix_var.startswith(',') or self.suffix_var.startswith(' ')):
427 | self.suffix_var = ' ' + self.suffix_var
430 | caption = caption.replace(self.replace, self.replace_with)
431 | if self.prefix != '':
432 | if self.prefix.endswith(' '):
433 | self.prefix = self.prefix[:-1]
434 | if not self.prefix.endswith(','):
435 | self.prefix = self.prefix+','
436 | caption = self.prefix + ' ' + caption
437 | if caption.endswith(',') or caption.endswith('.'):
438 | caption = caption[:-1]
439 | caption = caption +', ' + self.suffix_var
440 | else:
441 | caption = caption + self.suffix_var
442 | #saving the captioned image
443 | if self.output_format == 'text':
444 | #text file with same name as image
445 | imgName = os.path.basename(self.image_list[self.image_index])
446 | imgName = imgName[:imgName.rfind('.')]
447 | caption_file = os.path.join(batch_output_dir, imgName + '.txt')
448 | with open(caption_file, 'w') as f:
449 | f.write(caption)
450 | elif self.output_format == 'filename':
451 | #duplicate image with caption as file name
452 | img.save(os.path.join(batch_output_dir, caption+'.png'))
453 | progress = (i + 1) / len(self.image_list)
454 | pb.set_progress(progress)
455 | #show message box when done
456 | pb.destroy()
457 | donemsg = tk.messagebox.showinfo("Batch Folder", "Batching complete!",parent=self.master)
458 | if len(self.bad_files) > 0:
459 | with open('bad_files.txt', 'w') as f:
460 | for item in self.bad_files:
461 | f.write(item + '\n')
462 | bad_files_msg = tk.messagebox.showinfo("Bad Files", "Couldn't process " + str(len(self.bad_files)) + " files.\nFor a list of problematic files see bad_files.txt", parent=self.master)
463 |
464 | #ask user if we should load the batch output folder
465 | ask3 = tk.messagebox.askquestion("Batch Folder", "Load batch output folder?")
466 | if ask3 == 'yes':
467 | self.image_index = 0
468 | self.open_folder(folder=batch_output_dir)
469 | #focus on donemsg
470 | #donemsg.focus_force()
471 | def generate_caption(self):
472 | #get the image
473 | tensor = transforms.Compose([
474 | #transforms.CenterCrop(SIZE),
475 | transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
476 | transforms.ToTensor(),
477 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
478 | ])
479 | torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
480 | if self.nucleus_sampling:
481 | captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
482 | else:
483 | captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
484 | max_length=48, repetition_penalty=self.q_factor)
485 | self.caption = captions[0]
486 | self.caption_entry.delete(0, tk.END)
487 | self.caption_entry.insert(0, self.caption)
488 | #change the caption entry color to red
489 | self.caption_entry.configure(fg_color='red')
490 | def load_blip_model(self):
491 | self.blipSize = 384
492 | blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
493 | #check if options file exists
494 | if os.path.exists(os.path.join(self.captioner_folder, 'options.json')):
495 | with open(os.path.join(self.captioner_folder, 'options.json'), 'r') as f:
496 | options = json.load(f)  #load once; a second json.load on the same handle would hit EOF and raise
497 | self.nucleus_sampling = options['nucleus_sampling']
498 | self.q_factor, self.min_length = options['q_factor'], options['min_length']
499 | else:
500 | self.nucleus_sampling = False
501 | self.q_factor = 1.0
502 | self.min_length = 22
503 | config_path = os.path.join(self.captioner_folder, "BLIP/configs/med_config.json")
504 | cache_folder = os.path.join(self.captioner_folder, "BLIP/cache")
505 | model_path = os.path.join(self.captioner_folder, "BLIP/models/model_base_caption_capfilt_large.pth")
506 | if not os.path.exists(cache_folder):
507 | os.makedirs(cache_folder)
508 |
509 | if not os.path.exists(model_path):
510 | print(f"Downloading BLIP to {cache_folder}")
511 | with requests.get(blip_model_url, stream=True) as session:
512 | session.raise_for_status()
513 | with open(model_path, 'wb') as f:
514 | for chunk in session.iter_content(chunk_size=1024):
515 | f.write(chunk)
516 | print('Download complete')
517 | else:
518 | print(f"Found BLIP model")
519 | import models.blip
520 | blip_decoder = models.blip.blip_decoder(pretrained=model_path, image_size=self.blipSize, vit='base', med_config=config_path)
521 | blip_decoder.eval()
522 | self.blip_decoder = blip_decoder.to(torch.device("cuda"))
523 |
524 | def batch_folder_mask(self):
525 | folder = ''
526 | try:
527 | # check if self.folder is set
528 | folder = self.folder
529 | except AttributeError:
530 | pass
531 |
532 | dialog = BatchMaskWindow(self, folder)
533 | dialog.mainloop()
534 |
535 | def load_clip_seg_model(self):
536 | if self.clip_seg is None:
537 | self.clip_seg = ClipSeg()
538 |
539 | def open_folder(self,folder=None):
540 | if folder is None:
541 | self.folder = fd.askdirectory()
542 | else:
543 | self.folder = folder
544 | if self.folder == '':
545 | return
546 | self.output_folder = self.folder
547 | self.image_list = [os.path.join(self.folder, f) for f in os.listdir(self.folder) if _check_file_type(f) and not f.endswith('-masklabel.png') and not f.endswith('-depth.png')]
548 | #self.image_list.sort()
549 | #sort the image list alphabetically so that the images are in the same order every time
550 | self.image_list.sort(key=lambda x: x.lower())
551 |
552 | self.image_count = len(self.image_list)
553 | if self.image_count == 0:
554 | tk.messagebox.showinfo("No Images", "No images found in the selected folder")
555 | return
556 | #update the image count label
557 |
558 | self.image_index = 0
559 | self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
560 | self.output_folder = self.folder
561 | self.load_image()
562 | self.caption_entry.focus_set()
563 |
564 | def draw_mask(self, event):
565 | if not self.enable_mask_editing.get():
566 | return
567 |
568 | if event.widget != self.image_label.children["!label"]:
569 | return
570 |
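# map the cursor position on the displayed (resized) image back to
# full-resolution pixel coordinates before drawing on the mask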
571 | start_x = int(event.x / self.image_size[0] * self.PILimage.width)
572 | start_y = int(event.y / self.image_size[1] * self.PILimage.height)
573 | end_x = int(self.mask_draw_x / self.image_size[0] * self.PILimage.width)
574 | end_y = int(self.mask_draw_y / self.image_size[1] * self.PILimage.height)
575 |
576 | self.mask_draw_x = event.x
577 | self.mask_draw_y = event.y
578 |
579 | color = None
580 |
581 | if event.state & 0x0100 or event.num == 1: # left mouse button
582 | color = (255, 255, 255)
583 | elif event.state & 0x0400 or event.num == 3: # right mouse button
584 | color = (0, 0, 0)
585 |
586 | if color is not None:
587 | if self.PILmask is None:
588 | self.PILmask = Image.new('RGB', size=self.PILimage.size, color=(0, 0, 0))
589 |
590 | draw = ImageDraw.Draw(self.PILmask)
591 | draw.line((start_x, start_y, end_x, end_y), fill=color, width=2 * self.mask_draw_radius + 1)
592 | draw.ellipse((start_x - self.mask_draw_radius, start_y - self.mask_draw_radius, start_x + self.mask_draw_radius, start_y + self.mask_draw_radius), fill=color, outline=None)
593 | draw.ellipse((end_x - self.mask_draw_radius, end_y - self.mask_draw_radius, end_x + self.mask_draw_radius, end_y + self.mask_draw_radius), fill=color, outline=None)
594 |
595 | self.compose_masked_image()
596 | self.display_image()
597 |
598 | def draw_mask_radius(self, event):
599 | if event.widget != self.image_label.children["!label"]:
600 | return
601 |
602 | delta = -np.sign(event.delta) * 5
603 | self.mask_draw_radius = max(1, self.mask_draw_radius + delta)  #keep the brush at least 1px wide
604 |
605 | def compose_masked_image(self):
606 | np_image = np.array(self.PILimage).astype(np.float32) / 255.0
607 | np_mask = np.array(self.PILmask).astype(np.float32) / 255.0
608 | np_mask = np.clip(np_mask, 0.4, 1.0)
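# the 0.4 floor keeps masked-out regions dimly visible instead of fully black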
609 | np_masked_image = (np_image * np_mask * 255.0).astype(np.uint8)
610 | self.image = Image.fromarray(np_masked_image, mode='RGB')
611 |
612 | def display_image(self):
613 | #resize to fit 600x600 while maintaining aspect ratio
614 | width, height = self.image.size
615 | if width > height:
616 | new_width = 600
617 | new_height = int(600 * height / width)
618 | else:
619 | new_height = 600
620 | new_width = int(600 * width / height)
621 | self.image_size = (new_width, new_height)
622 | self.image = self.image.resize(self.image_size, Image.Resampling.LANCZOS)
623 | self.image = ctk.CTkImage(self.image, size=self.image_size)
624 | self.image_label.configure(image=self.image)
625 |
626 | def load_image(self):
627 | try:
628 | self.PILimage = Image.open(self.image_list[self.image_index]).convert('RGB')
629 | except Exception:
630 | print(f'Error opening image {self.image_list[self.image_index]}')
631 | print('Logged path to bad_files.txt')
632 | #mode 'a' appends and creates bad_files.txt if it doesn't exist yet
633 | with open('bad_files.txt', 'a') as f:
634 | f.write(self.image_list[self.image_index]+'\n')
639 | return
640 |
641 | self.image = self.PILimage.copy()
642 |
643 | try:
644 | self.PILmask = None
645 | mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
646 | if os.path.exists(mask_filename):
647 | self.PILmask = Image.open(mask_filename).convert('RGB')
648 | self.compose_masked_image()
649 | except Exception as e:
650 | print(f'Error opening mask for {self.image_list[self.image_index]}')
651 | print('Logged path to bad_files.txt')
652 | #mode 'a' appends and creates bad_files.txt if it doesn't exist yet
653 | with open('bad_files.txt', 'a') as f:
654 | f.write(self.image_list[self.image_index]+'\n')
659 | return
660 |
661 | self.display_image()
662 |
663 | self.caption_file_path = self.image_list[self.image_index]
664 | self.caption_file_name = os.path.basename(self.caption_file_path)
665 | self.caption_file_ext = os.path.splitext(self.caption_file_name)[1]
666 | self.caption_file_name_no_ext = os.path.splitext(self.caption_file_name)[0]
667 | self.caption_file = os.path.join(self.folder, self.caption_file_name_no_ext + '.txt')
668 | if (os.path.isfile(self.caption_file) and not self.auto_generate_caption.get()) or (os.path.isfile(self.caption_file) and self.auto_generate_caption.get() and self.auto_generate_caption_text_override.get()):
669 | with open(self.caption_file, 'r') as f:
670 | self.caption = f.read()
671 | self.caption_entry.delete(0, tk.END)
672 | self.caption_entry.insert(0, self.caption)
673 | self.caption_entry.configure(fg_color=ThemeManager.theme["CTkEntry"]["fg_color"])
674 | self.use_blip = False
675 | elif (os.path.isfile(self.caption_file) and self.auto_generate_caption.get() and not self.auto_generate_caption_text_override.get()) or (not os.path.isfile(self.caption_file) and self.auto_generate_caption.get() and self.auto_generate_caption_text_override.get()):
676 | self.use_blip = True
677 | self.caption_entry.delete(0, tk.END)
678 | elif not os.path.isfile(self.caption_file) and not self.auto_generate_caption.get():
679 | self.caption_entry.delete(0, tk.END)
680 | return
681 | if self.use_blip and not self.debug:
682 | tensor = transforms.Compose([
683 | transforms.Resize((self.blipSize, self.blipSize), interpolation=InterpolationMode.BICUBIC),
684 | transforms.ToTensor(),
685 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
686 | ])
687 | torch_image = tensor(self.PILimage).unsqueeze(0).to(torch.device("cuda"))
688 | if self.nucleus_sampling:
689 | captions = self.blip_decoder.generate(torch_image, sample=True, top_p=self.q_factor)
690 | else:
691 | captions = self.blip_decoder.generate(torch_image, sample=False, num_beams=16, min_length=self.min_length, \
692 | max_length=48, repetition_penalty=self.q_factor)
693 | self.caption = captions[0]
694 | self.caption_entry.delete(0, tk.END)
695 | self.caption_entry.insert(0, self.caption)
696 | #change the caption entry color to red
697 | self.caption_entry.configure(fg_color='red')
698 |
699 | def save(self, event):
700 | self.save_caption()
701 |
702 | if self.enable_mask_editing.get():
703 | self.save_mask()
704 |
705 | def save_mask(self):
706 | mask_filename = os.path.splitext(self.image_list[self.image_index])[0] + '-masklabel.png'
707 | if self.PILmask is not None:
708 | self.PILmask.save(mask_filename)
709 |
710 | def save_caption(self):
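# Worked example (illustrative values): caption "a dog on grass" with
# Replace="dog", With="cat", Add to start="photo of", Add to end="high quality"
# is saved as "photo of, a cat on grass high quality" -- a ", " is only inserted
# before the suffix when the caption already ends with ',' or '.'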
711 | self.caption = self.caption_entry.get()
712 | self.replace = self.replace_entry.get()
713 | self.replace_with = self.with_entry.get()
714 | self.suffix_var = self.suffix_entry.get()
715 | self.prefix = self.prefix_entry.get()
716 | #prepare the caption
717 | self.caption = self.caption.replace(self.replace, self.replace_with)
718 | if not (self.suffix_var.startswith(',') or self.suffix_var.startswith(' ')):
719 | self.suffix_var = ' ' + self.suffix_var
722 | if self.prefix != '':
723 | if self.prefix.endswith(' '):
724 | self.prefix = self.prefix[:-1]
725 | if not self.prefix.endswith(','):
726 | self.prefix = self.prefix+','
727 | self.caption = self.prefix + ' ' + self.caption
728 | if self.caption.endswith(',') or self.caption.endswith('.'):
729 | self.caption = self.caption[:-1]
730 | self.caption = self.caption +', ' + self.suffix_var
731 | else:
732 | self.caption = self.caption + self.suffix_var
733 | self.caption = self.caption.strip()
734 | if self.output_folder != self.folder:
735 | outputFolder = self.output_folder
736 | else:
737 | outputFolder = self.folder
738 | if self.output_format == 'text':
739 | #text file with same name as image
740 | #image name
741 | #print('test')
742 | imgName = os.path.basename(self.image_list[self.image_index])
743 | imgName = imgName[:imgName.rfind('.')]
744 | self.caption_file = os.path.join(outputFolder, imgName + '.txt')
745 | with open(self.caption_file, 'w') as f:
746 | f.write(self.caption)
747 | elif self.output_format == 'filename':
748 | #duplicate image with caption as file name
749 | #make sure self.caption doesn't contain any illegal characters
750 | illegal_chars = ['/', '\\', ':', '*', '?', '"', "'",'<', '>', '|', '.']
751 | for char in illegal_chars:
752 | self.caption = self.caption.replace(char, '')
753 | self.PILimage.save(os.path.join(outputFolder, self.caption+'.png'))
754 | self.caption_entry.delete(0, tk.END)
755 | self.caption_entry.insert(0, self.caption)
756 | self.caption_entry.configure(fg_color='green')
757 |
758 | self.caption_entry.focus_force()
759 | def delete_word(self,event):
760 | ent = event.widget
761 | end_idx = ent.index(tk.INSERT)
762 | start_idx = ent.get().rfind(" ", None, end_idx)
763 | ent.selection_range(max(start_idx, 0), end_idx)  #rfind returns -1 when no space precedes the cursor
764 | def prev_image(self, event):
765 | if self.image_index > 0:
766 | self.image_index -= 1
767 | self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
768 | self.load_image()
769 | self.caption_entry.focus_set()
770 | self.caption_entry.focus_force()
771 | def next_image(self, event):
772 | if self.image_index < len(self.image_list) - 1:
773 | self.image_index += 1
774 | self.image_count_label.configure(text=f'Image {self.image_index+1} of {self.image_count}')
775 | self.load_image()
776 | self.caption_entry.focus_set()
777 | self.caption_entry.focus_force()
778 | def open_options(self):
779 | self.options_window = ctk.CTkToplevel(self)
780 | self.options_window.title("Options")
781 | self.options_window.geometry("320x550")
782 | #disable resize
783 | self.options_window.resizable(False, False)
784 | self.options_window.focus_force()
785 | self.options_window.grab_set()
786 | self.options_window.transient(self)
787 | self.options_window.protocol("WM_DELETE_WINDOW", self.close_options)
788 | #add title label
789 | self.options_title_label = ctk.CTkLabel(self.options_window, text="Options",font=ctk.CTkFont(size=20, weight="bold"))
790 | self.options_title_label.pack(side="top", pady=5)
791 | #add an entry with a button to select a folder as output folder
792 | self.output_folder_label = ctk.CTkLabel(self.options_window, text="Output Folder")
793 | self.output_folder_label.pack(side="top", pady=5)
794 | self.output_folder_entry = ctk.CTkEntry(self.options_window)
795 | self.output_folder_entry.pack(side="top", fill="x", expand=False,padx=15, pady=5)
796 | self.output_folder_entry.insert(0, self.output_folder)
797 | self.output_folder_button = ctk.CTkButton(self.options_window, text="Select Folder", command=self.select_output_folder,fg_color=("gray75", "gray25"))
798 | self.output_folder_button.pack(side="top", pady=5)
799 | #add radio buttons to select the output format between text and filename
800 | self.output_format_label = ctk.CTkLabel(self.options_window, text="Output Format")
801 | self.output_format_label.pack(side="top", pady=5)
802 | self.output_format_var = tk.StringVar(self.options_window)
803 | self.output_format_var.set(self.output_format)
804 | self.output_format_text = ctk.CTkRadioButton(self.options_window, text="Text File", variable=self.output_format_var, value="text")
805 | self.output_format_text.pack(side="top", pady=5)
806 | self.output_format_filename = ctk.CTkRadioButton(self.options_window, text="File name", variable=self.output_format_var, value="filename")
807 | self.output_format_filename.pack(side="top", pady=5)
808 | #add BLIP settings section
809 | self.blip_settings_label = ctk.CTkLabel(self.options_window, text="BLIP Settings",font=ctk.CTkFont(size=20, weight="bold"))
810 | self.blip_settings_label.pack(side="top", pady=10)
811 | #add a checkbox to use nucleus sampling or not
812 | self.nucleus_sampling_var = tk.IntVar(self.options_window)
813 | self.nucleus_sampling_checkbox = ctk.CTkCheckBox(self.options_window, text="Use nucleus sampling", variable=self.nucleus_sampling_var)
814 | self.nucleus_sampling_checkbox.pack(side="top", pady=5)
815 | if self.debug:
816 | self.nucleus_sampling = 0
817 | self.q_factor = 0.5
818 | self.min_length = 10
819 | self.nucleus_sampling_var.set(self.nucleus_sampling)
820 | #add a float entry to set the q factor
821 | self.q_factor_label = ctk.CTkLabel(self.options_window, text="Q Factor")
822 | self.q_factor_label.pack(side="top", pady=5)
823 | self.q_factor_entry = ctk.CTkEntry(self.options_window)
824 | self.q_factor_entry.insert(0, self.q_factor)
825 | self.q_factor_entry.pack(side="top", pady=5)
826 | #add an int entry to set the minimum caption length
827 | self.min_length_label = ctk.CTkLabel(self.options_window, text="Minimum Length")
828 | self.min_length_label.pack(side="top", pady=5)
829 | self.min_length_entry = ctk.CTkEntry(self.options_window)
830 | self.min_length_entry.insert(0, self.min_length)
831 | self.min_length_entry.pack(side="top", pady=5)
832 | #add a horizontal radio button to select between None, ViT-L-14/openai, ViT-H-14/laion2b_s32b_b79k
833 | #self.model_label = ctk.CTkLabel(self.options_window, text="CLIP Interrogation")
834 | #self.model_label.pack(side="top")
835 | #self.model_var = tk.StringVar(self.options_window)
836 | #self.model_var.set(self.model)
837 | #self.model_none = tk.Radiobutton(self.options_window, text="None", variable=self.model_var, value="None")
838 | #self.model_none.pack(side="top")
839 | #self.model_vit_l_14 = tk.Radiobutton(self.options_window, text="ViT-L-14/openai", variable=self.model_var, value="ViT-L-14/openai")
840 | #self.model_vit_l_14.pack(side="top")
841 | #self.model_vit_h_14 = tk.Radiobutton(self.options_window, text="ViT-H-14/laion2b_s32b_b79k", variable=self.model_var, value="ViT-H-14/laion2b_s32b_b79k")
842 | #self.model_vit_h_14.pack(side="top")
843 |
844 | #add a save button
845 | self.save_button = ctk.CTkButton(self.options_window, text="Save", command=self.save_options, fg_color=("gray75", "gray25"))
846 | self.save_button.pack(side="top",fill='x',pady=10,padx=10)
847 | #all entries list
848 | entries = [self.output_folder_entry, self.q_factor_entry, self.min_length_entry]
849 | #bind the right click to all entries
850 | for entry in entries:
851 | entry.bind("<Button-3>", self.create_right_click_menu)
852 | self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
853 | if os.path.isfile(self.options_file):
854 | with open(self.options_file, 'r') as f:
855 | self.options = json.load(f)
856 | self.output_folder_entry.delete(0, tk.END)
857 | self.output_folder_entry.insert(0, self.output_folder)
858 | self.output_format_var.set(self.options['output_format'])
859 | self.nucleus_sampling_var.set(self.options['nucleus_sampling'])
860 | self.q_factor_entry.delete(0, tk.END)
861 | self.q_factor_entry.insert(0, self.options['q_factor'])
862 | self.min_length_entry.delete(0, tk.END)
863 | self.min_length_entry.insert(0, self.options['min_length'])
864 | def load_options(self):
865 | self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
866 | if os.path.isfile(self.options_file):
867 | with open(self.options_file, 'r') as f:
868 | self.options = json.load(f)
869 | #self.output_folder = self.folder
870 | #self.output_folder = self.options['output_folder']
871 | if 'folder' in self.__dict__:
872 | self.output_folder = self.folder
873 | else:
874 | self.output_folder = ''
875 | self.output_format = self.options['output_format']
876 | self.nucleus_sampling = self.options['nucleus_sampling']
877 | self.q_factor = self.options['q_factor']
878 | self.min_length = self.options['min_length']
879 | else:
880 | #if self has folder, use it, otherwise use the current folder
881 | if 'folder' in self.__dict__ :
882 | self.output_folder = self.folder
883 | else:
884 | self.output_folder = ''
885 | self.output_format = "text"
886 | self.nucleus_sampling = False
887 | self.q_factor = 0.9
888 | self.min_length = 22
889 | def save_options(self):
890 | self.output_folder = self.output_folder_entry.get()
891 | self.output_format = self.output_format_var.get()
892 | self.nucleus_sampling = self.nucleus_sampling_var.get()
893 | self.q_factor = float(self.q_factor_entry.get())
894 | self.min_length = int(self.min_length_entry.get())
895 | #save options to a file
896 | self.options_file = os.path.join(self.captioner_folder, 'captioner_options.json')
897 | with open(self.options_file, 'w') as f:
898 | json.dump({'output_folder': self.output_folder, 'output_format': self.output_format, 'nucleus_sampling': self.nucleus_sampling, 'q_factor': self.q_factor, 'min_length': self.min_length}, f)
899 | self.close_options()
900 |
901 | def select_output_folder(self):
902 | self.output_folder = fd.askdirectory()
903 | self.output_folder_entry.delete(0, tk.END)
904 | self.output_folder_entry.insert(0, self.output_folder)
905 | def close_options(self):
906 | self.options_window.destroy()
907 | self.caption_entry.focus_force()
908 | def create_right_click_menu(self, event):
909 | #create a menu
910 | self.menu = Menu(self, tearoff=0)
911 | #add commands to the menu
912 | self.menu.add_command(label="Cut", command=lambda: self.focus_get().event_generate("<>"))
913 | self.menu.add_command(label="Copy", command=lambda: self.focus_get().event_generate("<>"))
914 | self.menu.add_command(label="Paste", command=lambda: self.focus_get().event_generate("<>"))
915 | self.menu.add_command(label="Select All", command=lambda: self.focus_get().event_generate("<>"))
916 | #display the menu
917 | try:
918 | self.menu.tk_popup(event.x_root, event.y_root)
919 | finally:
920 | #make sure to release the grab (Tk 8.0a1 only)
921 | self.menu.grab_release()
922 |
923 |
924 | #progress bar class with cancel button
925 | class ProgressbarWithCancel(ctk.CTkToplevel):
926 | def __init__(self,max=None, **kw):
927 | super().__init__(**kw)
928 | self.title("Batching...")
929 | self.max = max
930 | self.possibleLabels = ['Searching for answers...',"I'm working, I promise.",'ARE THOSE TENTACLES?!','Weird data man...','Another one bites the dust' ,"I think it's a cat?" ,'Looking for the meaning of life', 'Dreaming of captions']
931 |
932 | self.label = ctk.CTkLabel(self, text="Searching for answers...")
933 | self.label.pack(side="top", fill="x", expand=True,padx=10,pady=10)
934 | self.progress = ctk.CTkProgressBar(self, orientation="horizontal", mode="determinate")
935 | self.progress.pack(side="left", fill="x", expand=True,padx=10,pady=10)
936 | self.cancel_button = ctk.CTkButton(self, text="Cancel", command=self.cancel)
937 | self.cancel_button.pack(side="right",padx=10,pady=10)
938 | self.cancelled = False
939 | self.count_label = ctk.CTkLabel(self, text="0/{0}".format(self.max))
940 | self.count_label.pack(side="right",padx=10,pady=10)
941 | def set_random_label(self):
942 | import random
943 | self.label["text"] = random.choice(self.possibleLabels)
944 | #pop from list
945 | #self.possibleLabels.remove(self.label["text"])
946 | def cancel(self):
947 | self.cancelled = True
948 | def set_progress(self, value):
949 | self.progress.set(value)
950 | self.count_label.configure(text="{0}/{1}".format(int(value * self.max), self.max))
951 | def get_progress(self):
952 | return self.progress.get()
953 | def set_max(self, value):
954 | self.max = value
955 | def get_max(self):
956 | return self.max
957 | def is_cancelled(self):
958 | return self.cancelled
959 | #quit the progress bar window
960 |
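# Minimal usage sketch for ProgressbarWithCancel (mirrors batch_folder_caption
# above; 'items' and 'process' are illustrative placeholders):
#
#   pb = ProgressbarWithCancel(max=len(items))
#   for i, item in enumerate(items):
#       if pb.is_cancelled():
#           pb.destroy()
#           break
#       process(item)
#       pb.set_progress((i + 1) / len(items))
#   else:
#       pb.destroy()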
961 |
962 | #run when executed as a script
963 | if __name__ == "__main__":
964 |
965 | #root = tk.Tk()
966 | app = ImageBrowser()
967 | app.mainloop()
968 |
--------------------------------------------------------------------------------
/scripts/clip_segmentation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Optional, Callable
4 |
5 | import torch
6 | from PIL import Image
7 | from torch import Tensor, nn
8 | from torchvision.transforms import transforms, functional
9 | from tqdm.auto import tqdm
10 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
11 |
12 | DEVICE = "cuda"
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description="ClipSeg script.")
17 | parser.add_argument(
18 | "--sample_dir",
19 | type=str,
20 | required=True,
21 | help="directory where samples are located",
22 | )
23 | parser.add_argument(
24 | "--add_prompt",
25 | type=str,
26 | required=True,
27 | action="append",
28 | help="a prompt used to create a mask",
29 | dest="prompts",
30 | )
31 | parser.add_argument(
32 | "--mode",
33 | type=str,
34 | default='fill',
35 | required=False,
36 | help="Either replace, fill, add or subtract",
37 | )
38 | parser.add_argument(
39 | "--threshold",
40 | type=float,
41 | default=0.3,
42 | required=False,
43 | help="threshold for including pixels in the mask",
44 | )
45 | parser.add_argument(
46 | "--smooth_pixels",
47 | type=int,
48 | default=5,
49 | required=False,
50 | help="radius of a smoothing operation applied to the generated mask",
51 | )
52 | parser.add_argument(
53 | "--expand_pixels",
54 | type=int,
55 | default=10,
56 | required=False,
57 | help="amount of expansion of the generated mask in all directions",
58 | )
59 |
60 | args = parser.parse_args()
61 | return args
62 |
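# Example invocation (illustrative paths):
#   python scripts/clip_segmentation.py --sample_dir ./data \
#       --add_prompt "a person" --mode fill --threshold 0.3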
63 |
64 | class MaskSample:
65 | def __init__(self, filename: str):
66 | self.image_filename = filename
67 | self.mask_filename = os.path.splitext(filename)[0] + "-masklabel.png"
68 |
69 | self.image = None
70 | self.mask_tensor = None
71 |
72 | self.height = 0
73 | self.width = 0
74 |
75 | self.image2Tensor = transforms.Compose([
76 | transforms.ToTensor(),
77 | ])
78 |
79 | self.tensor2Image = transforms.Compose([
80 | transforms.ToPILImage(),
81 | ])
82 |
83 | def get_image(self) -> Image.Image:
84 | if self.image is None:
85 | self.image = Image.open(self.image_filename).convert('RGB')
86 | self.height = self.image.height
87 | self.width = self.image.width
88 |
89 | return self.image
90 |
91 | def get_mask_tensor(self) -> Tensor:
92 | if self.mask_tensor is None and os.path.exists(self.mask_filename):
93 | mask = Image.open(self.mask_filename).convert('L')
94 | mask = self.image2Tensor(mask)
95 | mask = mask.to(DEVICE)
96 | self.mask_tensor = mask.unsqueeze(0)
97 |
98 | return self.mask_tensor
99 |
100 | def set_mask_tensor(self, mask_tensor: Tensor):
101 | self.mask_tensor = mask_tensor
102 |
103 | def add_mask_tensor(self, mask_tensor: Tensor):
104 | mask = self.get_mask_tensor()
105 | if mask is None:
106 | mask = mask_tensor
107 | else:
108 | mask += mask_tensor
109 | mask = torch.clamp(mask, 0, 1)
110 |
111 | self.mask_tensor = mask
112 |
113 | def subtract_mask_tensor(self, mask_tensor: Tensor):
114 | mask = self.get_mask_tensor()
115 | if mask is None:
116 | mask = mask_tensor
117 | else:
118 | mask -= mask_tensor
119 | mask = torch.clamp(mask, 0, 1)
120 |
121 | self.mask_tensor = mask
122 |
123 | def save_mask(self):
124 | if self.mask_tensor is not None:
125 | mask = self.mask_tensor.cpu().squeeze()
126 | mask = self.tensor2Image(mask).convert('RGB')
127 | mask.save(self.mask_filename)
128 |
129 |
130 | class ClipSeg:
131 | def __init__(self):
132 | self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
133 |
134 | self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
135 | self.model.eval()
136 | self.model.to(DEVICE)
137 |
138 | self.smoothing_kernel_radius = None
139 | self.smoothing_kernel = self.__create_average_kernel(self.smoothing_kernel_radius)
140 |
141 | self.expand_kernel_radius = None
142 | self.expand_kernel = self.__create_average_kernel(self.expand_kernel_radius)
143 |
144 | @staticmethod
145 | def __create_average_kernel(kernel_radius: Optional[int]):
146 | if kernel_radius is None:
147 | return None
148 |
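# build a (2r+1)x(2r+1) box filter as a fixed-weight Conv2d; convolving with it
# averages each pixel with its neighbourhood (used to smooth and expand masks)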
149 | kernel_size = kernel_radius * 2 + 1
150 | kernel_weights = torch.ones(1, 1, kernel_size, kernel_size) / (kernel_size * kernel_size)
151 | kernel = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, bias=False, padding_mode='replicate', padding=kernel_radius)
152 | kernel.weight.data = kernel_weights
153 | kernel.requires_grad_(False)
154 | kernel.to(DEVICE)
155 | return kernel
156 |
157 | @staticmethod
158 | def __get_sample_filenames(sample_dir: str) -> [str]:
159 | filenames = []
160 | for filename in os.listdir(sample_dir):
161 | ext = os.path.splitext(filename)[1].lower()
162 | if ext in ['.jpg', '.jpeg', '.png', '.bmp', '.webp'] and '-masklabel.png' not in filename:
163 | filenames.append(os.path.join(sample_dir, filename))
164 |
165 | return filenames
166 |
167 | def __process_mask(self, mask: Tensor, target_height: int, target_width: int, threshold: float) -> Tensor:
168 | while len(mask.shape) < 4:
169 | mask = mask.unsqueeze(0)
170 |
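# pipeline: logits -> sigmoid -> merge per-prompt channels -> smooth -> resize
# to the source image -> binarize at `threshold` -> expand -> re-binarize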
171 | mask = torch.sigmoid(mask)
172 | mask = mask.sum(1).unsqueeze(1)
173 | if self.smoothing_kernel is not None:
174 | mask = self.smoothing_kernel(mask)
175 | mask = functional.resize(mask, [target_height, target_width])
176 | mask = (mask > threshold).float()
177 | if self.expand_kernel is not None:
178 | mask = self.expand_kernel(mask)
179 | mask = (mask > 0).float()
180 |
181 | return mask
182 |
183 | def mask_image(self, filename: str, prompts: [str], mode: str = 'fill', threshold: float = 0.3, smooth_pixels: int = 5, expand_pixels: int = 10):
184 | """
185 | Masks a sample
186 |
187 | Parameters:
188 | filename (`str`): a sample filename
189 | prompts (`[str]`): a list of prompts used to create a mask
190 | mode (`str`): can be one of
191 | - replace: creates new masks for all samples, even if a mask already exists
192 | - fill: creates new masks for all samples without a mask
193 | - add: adds the new region to existing masks
194 | - subtract: subtracts the new region from existing masks
195 | threshold (`float`): threshold for including pixels in the mask
196 | smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
197 | expand_pixels (`int`): amount of expansion of the generated mask in all directions
198 | """
199 |
200 | mask_sample = MaskSample(filename)
201 |
202 | if mode == 'fill' and mask_sample.get_mask_tensor() is not None:
203 | return
204 |
205 | if self.smoothing_kernel_radius != smooth_pixels:
206 | self.smoothing_kernel = self.__create_average_kernel(smooth_pixels)
207 | self.smoothing_kernel_radius = smooth_pixels
208 |
209 | if self.expand_kernel_radius != expand_pixels:
210 | self.expand_kernel = self.__create_average_kernel(expand_pixels)
211 | self.expand_kernel_radius = expand_pixels
212 |
213 | inputs = self.processor(text=prompts, images=[mask_sample.get_image()] * len(prompts), padding="max_length", return_tensors="pt")
214 | inputs.to(DEVICE)
215 | with torch.no_grad():
216 | outputs = self.model(**inputs)
217 | predicted_mask = self.__process_mask(outputs.logits, mask_sample.height, mask_sample.width, threshold)
218 |
219 | if mode == 'replace' or mode == 'fill':
220 | mask_sample.set_mask_tensor(predicted_mask)
221 | elif mode == 'add':
222 | mask_sample.add_mask_tensor(predicted_mask)
223 | elif mode == 'subtract':
224 | mask_sample.subtract_mask_tensor(predicted_mask)
225 |
226 | mask_sample.save_mask()
227 |
228 | def mask_folder(
229 | self,
230 | sample_dir: str,
231 | prompts: [str],
232 | mode: str = 'fill',
233 | threshold: float = 0.3,
234 | smooth_pixels: int = 5,
235 | expand_pixels: int = 10,
236 | progress_callback: Callable[[int, int], None] = None,
237 | error_callback: Callable[[str], None] = None,
238 | ):
239 | """
240 | Masks all samples in a folder
241 |
242 | Parameters:
243 | sample_dir (`str`): directory where samples are located
244 | prompts (`[str]`): a list of prompts used to create a mask
245 | mode (`str`): can be one of
246 | - replace: creates new masks for all samples, even if a mask already exists
247 | - fill: creates new masks for all samples without a mask
248 | - add: adds the new region to existing masks
249 | - subtract: subtracts the new region from existing masks
250 | threshold (`float`): threshold for including pixels in the mask
251 | smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
252 | expand_pixels (`int`): amount of expansion of the generated mask in all directions
253 | progress_callback (`Callable[[int, int], None]`): called after every processed image
254 | error_callback (`Callable[[str], None]`): called for every exception
255 | """
256 |
257 | filenames = self.__get_sample_filenames(sample_dir)
258 | self.mask_images(
259 | filenames=filenames,
260 | prompts=prompts,
261 | mode=mode,
262 | threshold=threshold,
263 | smooth_pixels=smooth_pixels,
264 | expand_pixels=expand_pixels,
265 | progress_callback=progress_callback,
266 | error_callback=error_callback,
267 | )
268 |
269 | def mask_images(
270 | self,
271 | filenames: [str],
272 | prompts: [str],
273 | mode: str = 'fill',
274 | threshold: float = 0.3,
275 | smooth_pixels: int = 5,
276 | expand_pixels: int = 10,
277 | progress_callback: Callable[[int, int], None] = None,
278 | error_callback: Callable[[str], None] = None,
279 | ):
280 | """
281 | Masks all samples in a list
282 |
283 | Parameters:
284 | filenames (`[str]`): a list of sample filenames
285 | prompts (`[str]`): a list of prompts used to create a mask
286 | mode (`str`): can be one of
287 | - replace: creates new masks for all samples, even if a mask already exists
288 | - fill: creates new masks for all samples without a mask
289 | - add: adds the new region to existing masks
290 | - subtract: subtracts the new region from existing masks
291 | threshold (`float`): threshold for including pixels in the mask
292 | smooth_pixels (`int`): radius of a smoothing operation applied to the generated mask
293 | expand_pixels (`int`): amount of expansion of the generated mask in all directions
294 | progress_callback (`Callable[[int, int], None]`): called after every processed image
295 | error_callback (`Callable[[str], None]`): called for every exception
296 | """
297 |
298 | if progress_callback is not None:
299 | progress_callback(0, len(filenames))
300 | for i, filename in enumerate(tqdm(filenames)):
301 | try:
302 | self.mask_image(filename, prompts, mode, threshold, smooth_pixels, expand_pixels)
303 | except Exception as e:
304 | if error_callback is not None:
305 | error_callback(filename)
306 | if progress_callback is not None:
307 | progress_callback(i + 1, len(filenames))
308 |
309 |
310 | def main():
311 | args = parse_args()
312 | clip_seg = ClipSeg()
313 | clip_seg.mask_folder(
314 | sample_dir=args.sample_dir,
315 | prompts=args.prompts,
316 | mode=args.mode,
317 | threshold=args.threshold,
318 | smooth_pixels=args.smooth_pixels,
319 | expand_pixels=args.expand_pixels,
320 | error_callback=lambda filename: print("Error while processing image " + filename)
321 | )
322 |
323 |
324 | if __name__ == "__main__":
325 | main()
326 |
--------------------------------------------------------------------------------
/scripts/convert_diffusers_to_sd_cli.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | try:
4 | import converters
5 | except ImportError:
6 |
7 | #if there's a scripts folder where the script is, add it to the path
8 | if 'scripts' in os.listdir(os.path.dirname(os.path.abspath(__file__))):
9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'scripts'))
10 | else:
11 | print('Could not find scripts folder. Please add it to the path manually or place this file in it.')
12 | import converters
13 |
14 |
15 | if __name__ == '__main__':
16 | args = sys.argv[1:]
17 | if len(args) != 2:
18 | print('Usage: python3 convert_diffusers_to_sd.py <model_path> <output_path>')
19 | sys.exit(1)
20 | model_path = args[0]
21 | output_path = args[1]
22 | converters.Convert_Diffusers_to_SD(model_path, output_path)
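# Example invocation (illustrative paths):
#   python scripts/convert_diffusers_to_sd_cli.py ./diffusers_model ./model.ckpt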
23 |
--------------------------------------------------------------------------------
/scripts/converters.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import requests
16 | import os
17 | import os.path as osp
18 | import torch
19 | try:
20 | from omegaconf import OmegaConf
21 | except ImportError:
22 | raise ImportError(
23 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
24 | )
25 |
26 | from diffusers import (
27 | AutoencoderKL,
28 | DDIMScheduler,
29 | DPMSolverMultistepScheduler,
30 | EulerAncestralDiscreteScheduler,
31 | EulerDiscreteScheduler,
32 | HeunDiscreteScheduler,
33 | LDMTextToImagePipeline,
34 | LMSDiscreteScheduler,
35 | PNDMScheduler,
36 | StableDiffusionPipeline,
37 | UNet2DConditionModel,
38 | DiffusionPipeline
39 | )
40 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41 | #from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
42 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
43 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig, CLIPTextConfig
44 | import model_util
45 |
46 | class Convert_SD_to_Diffusers():
47 |
48 | def __init__(self, checkpoint_path, output_path, prediction_type=None, img_size=None, original_config_file=None, extract_ema=False, num_in_channels=None,pipeline_type=None,scheduler_type=None,sd_version=None,half=None,version=None):
49 | self.checkpoint_path = checkpoint_path
50 | self.output_path = output_path
51 | self.prediction_type = prediction_type
52 | self.img_size = img_size
53 | self.original_config_file = original_config_file
54 | self.extract_ema = extract_ema
55 | self.num_in_channels = num_in_channels
56 | self.pipeline_type = pipeline_type
57 | self.scheduler_type = scheduler_type
58 | self.sd_version = sd_version
59 | self.half = half
60 | self.version = version
61 | self.main()
62 |
63 |
64 | def main(self):
65 | image_size = self.img_size
66 | prediction_type = self.prediction_type
67 | original_config_file = self.original_config_file
68 | num_in_channels = self.num_in_channels
69 | scheduler_type = self.scheduler_type
70 | pipeline_type = self.pipeline_type
71 | extract_ema = self.extract_ema
72 | reference_diffusers_model = None
73 | #compute both flags up front so they are always bound
74 | is_v1 = self.version == 'v1'
75 | is_v2 = self.version == 'v2'
76 | if is_v2 and prediction_type == 'vprediction':
77 | reference_diffusers_model = 'stabilityai/stable-diffusion-2'
78 | if is_v2 and prediction_type == 'epsilon':
79 | reference_diffusers_model = 'stabilityai/stable-diffusion-2-base'
80 | if is_v1 and prediction_type == 'epsilon':
81 | reference_diffusers_model = 'runwayml/stable-diffusion-v1-5'
85 | dtype = 'fp16' if self.half else None
86 | v2_model = is_v2
87 | print(f"loading model from: {self.checkpoint_path}")
88 | #print(v2_model)
89 | text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, self.checkpoint_path)
90 | print(f"copy scheduler/tokenizer config from: {reference_diffusers_model}")
91 | model_util.save_diffusers_checkpoint(v2_model, self.output_path, text_encoder, unet, reference_diffusers_model, vae)
92 | print(f"Diffusers model saved.")
93 |
94 |
95 |
96 | class Convert_Diffusers_to_SD():
97 | def __init__(self,model_path=None, output_path=None):
98 | self.main(model_path, output_path)
99 | def main(self, model_path: str, output_path: str):
100 | #print(model_path)
101 | #print(output_path)
102 | global_step = None
103 | epoch = None
104 | dtype = torch.float32
105 | pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype, tokenizer=None, safety_checker=None)
106 | text_encoder = pipe.text_encoder
107 | vae = pipe.vae
108 | unet = pipe.unet
109 | v2_model = unet.config.cross_attention_dim == 1024
110 | original_model = None
111 | key_count = model_util.save_stable_diffusion_checkpoint(v2_model, output_path, text_encoder, unet,
112 | original_model, epoch, global_step, dtype, vae)
113 | print(f"Saved model")
114 | return main(model_path, output_path)
--------------------------------------------------------------------------------
/scripts/lion_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from lion_pytorch.lion_pytorch import Lion
2 |
--------------------------------------------------------------------------------
/scripts/lion_pytorch/lion_pytorch.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, Callable
2 |
3 | import torch
4 | from torch.optim.optimizer import Optimizer
5 |
6 | # functions
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | # update functions
12 |
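# The Lion update implemented below (Chen et al. 2023, "Symbolic Discovery of
# Optimization Algorithms"):
#   c_t = sign(beta1 * m_{t-1} + (1 - beta1) * g_t)   # interpolated sign step
#   p_t = p_{t-1} * (1 - lr * wd) - lr * c_t          # decoupled weight decay
#   m_t = beta2 * m_{t-1} + (1 - beta2) * g_t         # momentum tracking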
13 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
14 | # stepweight decay
15 |
16 | p.data.mul_(1 - lr * wd)
17 |
18 | # weight update
19 |
20 | update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
21 | p.add_(update, alpha = -lr)
22 |
23 | # decay the momentum running average coefficient
24 |
25 | exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
26 |
27 | # class
28 |
29 | class Lion(Optimizer):
30 | def __init__(
31 | self,
32 | params,
33 | lr: float = 1e-4,
34 | betas: Tuple[float, float] = (0.9, 0.99),
35 | weight_decay: float = 0.0,
36 | use_triton: bool = False
37 | ):
38 | assert lr > 0.
39 | assert all([0. <= beta <= 1. for beta in betas])
40 |
41 | defaults = dict(
42 | lr = lr,
43 | betas = betas,
44 | weight_decay = weight_decay
45 | )
46 |
47 | super().__init__(params, defaults)
48 |
49 | self.update_fn = update_fn
50 |
51 | if use_triton:
52 | from lion_pytorch.triton import update_fn as triton_update_fn
53 | self.update_fn = triton_update_fn
54 |
55 | @torch.no_grad()
56 | def step(
57 | self,
58 | closure: Optional[Callable] = None
59 | ):
60 |
61 | loss = None
62 | if exists(closure):
63 | with torch.enable_grad():
64 | loss = closure()
65 |
66 | for group in self.param_groups:
67 | for p in filter(lambda p: exists(p.grad), group['params']):
68 |
69 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
70 |
71 | # init state - exponential moving average of gradient values
72 |
73 | if len(state) == 0:
74 | state['exp_avg'] = torch.zeros_like(p)
75 |
76 | exp_avg = state['exp_avg']
77 |
78 | self.update_fn(
79 | p,
80 | grad,
81 | exp_avg,
82 | lr,
83 | wd,
84 | beta1,
85 | beta2
86 | )
87 |
88 | return loss
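# Minimal usage sketch ('model' and 'dataloader' are illustrative placeholders):
#
#   opt = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
#   for batch in dataloader:
#       loss = model(batch)
#       loss.backward()
#       opt.step()
#       opt.zero_grad()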
89 |
--------------------------------------------------------------------------------
/scripts/lion_pytorch/triton.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | try:
4 | import triton
5 | import triton.language as tl
6 | except ImportError as e:
7 | print('triton is not installed, please install by running `pip install triton -U --pre`')
8 | exit()
9 |
10 |
11 | @triton.autotune(configs = [
12 | triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),
13 | triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),
14 | ], key = ['n_elements'])
15 | @triton.jit
16 | def update_fn_kernel(
17 | p_ptr,
18 | grad_ptr,
19 | exp_avg_ptr,
20 | lr,
21 | wd,
22 | beta1,
23 | beta2,
24 | n_elements,
25 | BLOCK_SIZE: tl.constexpr,
26 | ):
27 | pid = tl.program_id(axis = 0)
28 |
29 | block_start = pid * BLOCK_SIZE
30 | offsets = block_start + tl.arange(0, BLOCK_SIZE)
31 |
32 | mask = offsets < n_elements
33 |
34 | # offsetted pointers
35 |
36 | offset_p_ptr = p_ptr + offsets
37 | offset_grad_ptr = grad_ptr + offsets
38 | offset_exp_avg_ptr = exp_avg_ptr + offsets
39 |
40 | # load
41 |
42 | p = tl.load(offset_p_ptr, mask = mask)
43 | grad = tl.load(offset_grad_ptr, mask = mask)
44 | exp_avg = tl.load(offset_exp_avg_ptr, mask = mask)
45 |
46 | # stepweight decay
47 |
48 | p = p * (1 - lr * wd)
49 |
50 | # diff between momentum running average and grad
51 |
52 | diff = exp_avg - grad
53 |
54 | # weight update
55 |
56 | update = diff * beta1 + grad
57 |
58 | # torch.sign
59 |
60 | can_update = update != 0
61 | update_sign = tl.where(update > 0, -lr, lr)
62 |
63 | p = p + update_sign * can_update
64 |
65 | # decay the momentum running average coefficient
66 |
67 | exp_avg = diff * beta2 + grad
68 |
69 | # store new params and momentum running average coefficient
70 |
71 | tl.store(offset_p_ptr, p, mask = mask)
72 | tl.store(offset_exp_avg_ptr, exp_avg, mask = mask)
73 |
74 | def update_fn(
75 | p: torch.Tensor,
76 | grad: torch.Tensor,
77 | exp_avg: torch.Tensor,
78 | lr: float,
79 | wd: float,
80 | beta1: float,
81 | beta2: float
82 | ):
83 | assert all([t.is_cuda for t in (p, grad, exp_avg)])
84 | n_elements = p.numel()
85 |
86 | grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
87 |
88 | update_fn_kernel[grid](
89 | p,
90 | grad,
91 | exp_avg,
92 | lr,
93 | wd,
94 | beta1,
95 | beta2,
96 | n_elements
97 | )
98 |
--------------------------------------------------------------------------------
/scripts/lora_utils.py:
--------------------------------------------------------------------------------
1 | # LoRA network module
2 | # reference:
3 | # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4 | # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5 |
6 | import math
7 | import os
8 | import torch
9 |
10 | from trainer_util import *
11 |
12 |
13 | class LoRAModule(torch.nn.Module):
14 | """
15 | replaces forward method of the original Linear, instead of replacing the original Linear module.
16 | """
17 |
18 | def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
19 | """ if alpha == 0 or None, alpha is rank (no scaling). """
20 | super().__init__()
21 | self.lora_name = lora_name
22 | self.lora_dim = lora_dim
23 |
24 | if org_module.__class__.__name__ == 'Conv2d':
25 | in_dim = org_module.in_channels
26 | out_dim = org_module.out_channels
27 | self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
28 | self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
29 | else:
30 | in_dim = org_module.in_features
31 | out_dim = org_module.out_features
32 | self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
33 | self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
34 |
35 | if type(alpha) == torch.Tensor:
36 | alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
37 | alpha = lora_dim if alpha is None or alpha == 0 else alpha
38 | self.scale = alpha / self.lora_dim
39 |         self.register_buffer('alpha', torch.tensor(alpha)) # stored as a buffer so it can be treated as a constant
40 |
41 | # same as microsoft's
42 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
43 | torch.nn.init.zeros_(self.lora_up.weight)
44 |
45 | self.multiplier = multiplier
46 |         self.org_module = org_module # reference dropped in apply_to()
47 |
48 | def apply_to(self):
49 | self.org_forward = self.org_module.forward
50 | self.org_module.forward = self.forward
51 | del self.org_module
52 |
53 | def forward(self, x):
54 | return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
55 |
56 |
57 | def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
58 | if network_dim is None:
59 | network_dim = 4 # default
60 | network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
61 | return network
62 |
63 |
64 | def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
65 | if os.path.splitext(file)[1] == '.safetensors':
66 | from safetensors.torch import load_file, safe_open
67 | weights_sd = load_file(file)
68 | else:
69 | weights_sd = torch.load(file, map_location='cpu')
70 |
71 | # get dim (rank)
72 | network_alpha = None
73 | network_dim = None
74 | for key, value in weights_sd.items():
75 | if network_alpha is None and 'alpha' in key:
76 | network_alpha = value
77 | if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
78 | network_dim = value.size()[0]
79 |
80 | if network_alpha is None:
81 | network_alpha = network_dim
82 |
83 | network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
84 | network.weights_sd = weights_sd
85 | return network
86 |
87 |
88 | class LoRANetwork(torch.nn.Module):
89 | UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
90 | TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
91 | LORA_PREFIX_UNET = 'lora_unet'
92 | LORA_PREFIX_TEXT_ENCODER = 'lora_te'
93 |
94 | def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
95 | super().__init__()
96 | self.multiplier = multiplier
97 | self.lora_dim = lora_dim
98 | self.alpha = alpha
99 |
100 | # create module instances
101 | def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
102 | loras = []
103 | for name, module in root_module.named_modules():
104 | if module.__class__.__name__ in target_replace_modules:
105 | for child_name, child_module in module.named_modules():
106 | if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
107 | lora_name = prefix + '.' + name + '.' + child_name
108 | lora_name = lora_name.replace('.', '_')
109 | lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
110 | loras.append(lora)
111 | return loras
112 |
113 | self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
114 | text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
115 | print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
116 |
117 | self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
118 | print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
119 |
120 | self.weights_sd = None
121 |
122 | # assertion
123 | names = set()
124 | for lora in self.text_encoder_loras + self.unet_loras:
125 | assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
126 | names.add(lora.lora_name)
127 |
128 | def load_weights(self, file):
129 | if os.path.splitext(file)[1] == '.safetensors':
130 | from safetensors.torch import load_file, safe_open
131 | self.weights_sd = load_file(file)
132 | else:
133 | self.weights_sd = torch.load(file, map_location='cpu')
134 |
135 | def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
136 | if self.weights_sd:
137 | weights_has_text_encoder = weights_has_unet = False
138 | for key in self.weights_sd.keys():
139 | if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
140 | weights_has_text_encoder = True
141 | elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
142 | weights_has_unet = True
143 |
144 | if apply_text_encoder is None:
145 | apply_text_encoder = weights_has_text_encoder
146 | else:
147 |                 assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / the saved weights and the text encoder flag are inconsistent"
148 |
149 | if apply_unet is None:
150 | apply_unet = weights_has_unet
151 | else:
152 |                 assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / the saved weights and the U-Net flag are inconsistent"
153 | else:
154 |             assert apply_text_encoder is not None and apply_unet is not None, "internal error: flag not set"
155 |
156 | if apply_text_encoder:
157 | print("enable LoRA for text encoder")
158 | else:
159 | self.text_encoder_loras = []
160 |
161 | if apply_unet:
162 | print("enable LoRA for U-Net")
163 | else:
164 | self.unet_loras = []
165 |
166 | for lora in self.text_encoder_loras + self.unet_loras:
167 | lora.apply_to()
168 | self.add_module(lora.lora_name, lora)
169 |
170 | if self.weights_sd:
171 | # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
172 | info = self.load_state_dict(self.weights_sd, False)
173 | print(f"weights are loaded: {info}")
174 |
175 | def enable_gradient_checkpointing(self):
176 | # not supported
177 | pass
178 |
179 | def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
180 | def enumerate_params(loras):
181 | params = []
182 | for lora in loras:
183 | params.extend(lora.parameters())
184 | return params
185 |
186 | self.requires_grad_(True)
187 | all_params = []
188 |
189 | if self.text_encoder_loras:
190 | param_data = {'params': enumerate_params(self.text_encoder_loras)}
191 | if text_encoder_lr is not None:
192 | param_data['lr'] = text_encoder_lr
193 | all_params.append(param_data)
194 |
195 | if self.unet_loras:
196 | param_data = {'params': enumerate_params(self.unet_loras)}
197 | if unet_lr is not None:
198 | param_data['lr'] = unet_lr
199 | all_params.append(param_data)
200 |
201 | return all_params
202 |
203 | def prepare_grad_etc(self, text_encoder, unet):
204 | self.requires_grad_(True)
205 |
206 | def on_epoch_start(self, text_encoder, unet):
207 | self.train()
208 |
209 | def get_trainable_params(self):
210 | return self.parameters()
211 |
212 | def save_weights(self, file, dtype, metadata):
213 | if metadata is not None and len(metadata) == 0:
214 | metadata = None
215 |
216 | state_dict = self.state_dict()
217 |
218 | if dtype is not None:
219 | for key in list(state_dict.keys()):
220 | v = state_dict[key]
221 | v = v.detach().clone().to("cpu").to(dtype)
222 | state_dict[key] = v
223 |
224 | if os.path.splitext(file)[1] == '.safetensors':
225 | from safetensors.torch import save_file
226 |
227 | # Precalculate model hashes to save time on indexing
228 | if metadata is None:
229 | metadata = {}
230 |             model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) # helper expected via `from trainer_util import *`; the original `train_util` name was never defined in this module
231 | metadata["sshs_model_hash"] = model_hash
232 | metadata["sshs_legacy_hash"] = legacy_hash
233 |
234 | save_file(state_dict, file, metadata)
235 | else:
236 | torch.save(state_dict, file)
--------------------------------------------------------------------------------
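
A sketch (not from the repo) of wiring the helpers above into a training script; the checkpoint id and learning rates are illustrative:

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from lora_utils import create_network

base = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
vae = AutoencoderKL.from_pretrained(base, subfolder="vae")  # accepted but unused by create_network

network = create_network(1.0, network_dim=4, network_alpha=1,
                         vae=vae, text_encoder=text_encoder, unet=unet)
# Swap the forward of every Linear / 1x1 Conv2d inside the targeted attention blocks.
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)

# Only the LoRA down/up projections train; the base weights stay frozen.
groups = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4)
optimizer = torch.optim.AdamW(groups)
```
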
/scripts/trainer_util.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import json
3 | import math
4 | from collections.abc import Iterable
5 | from pathlib import Path
6 | from typing import Dict, Generator, List, Optional, Tuple
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import einsum
10 | from einops import rearrange
11 | from accelerate.logging import get_logger
12 | from accelerate.utils import set_seed
13 | import diffusers
14 | from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler, EulerDiscreteScheduler
15 | from diffusers.optimization import get_scheduler
16 | from huggingface_hub import HfFolder, Repository, whoami
17 | from torchvision import transforms
18 | from tqdm.auto import tqdm
19 | from transformers import PretrainedConfig
20 | from PIL import Image, ImageFile, ImageFilter
21 | from dataloaders_util import *
22 | # FlashAttention based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main
23 | # /memory_efficient_attention_pytorch/flash_attention.py LICENSE MIT
24 | # https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE constants
25 | EPSILON = 1e-6
26 |
27 | class bcolors:
28 | HEADER = '\033[95m'
29 | OKBLUE = '\033[94m'
30 | OKCYAN = '\033[96m'
31 | OKGREEN = '\033[92m'
32 | WARNING = '\033[93m'
33 | FAIL = '\033[91m'
34 | ENDC = '\033[0m'
35 | BOLD = '\033[1m'
36 | UNDERLINE = '\033[4m'
37 | # helper functions
38 | def print_instructions():
39 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+G' to open up a GUI to play around with the model (will pause training){bcolors.ENDC}")
40 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+S' to save a checkpoint of the current epoch{bcolors.ENDC}")
41 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+P' to generate samples for current epoch{bcolors.ENDC}")
42 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+Q' to save and quit after the current epoch{bcolors.ENDC}")
43 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+S' to save a checkpoint of the current step{bcolors.ENDC}")
44 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+P' to generate samples for current step{bcolors.ENDC}")
45 | print(f"{bcolors.WARNING}Use 'CTRL+SHIFT+ALT+Q' to save and quit after the current step{bcolors.ENDC}")
46 | print('')
47 | print(f"{bcolors.WARNING}Use 'CTRL+H' to print this message again.{bcolors.ENDC}")
48 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
49 | if token is None:
50 | token = HfFolder.get_token()
51 | if organization is None:
52 | username = whoami(token)["name"]
53 | return f"{username}/{model_id}"
54 | else:
55 | return f"{organization}/{model_id}"
56 |
57 | #function to format a dictionary into a telegram message
58 | def format_dict(d):
59 | message = ""
60 | for key, value in d.items():
61 | #filter keys that have the word "token" in them
62 | if "token" in key and "tokenizer" not in key:
63 | value = "TOKEN"
64 |         if key == 'id' or key.endswith('_id'): #plain substring matching also hit keys like "width"
65 | value = "ID"
66 | #if value is a dictionary, format it recursively
67 | if isinstance(value, dict):
68 | for k, v in value.items():
69 | message += f"\n- {k}: {v} \n"
70 | elif isinstance(value, list):
71 | #each value is a new line in the message
72 | message += f"- {key}:\n\n"
73 | for v in value:
74 | message += f" {v}\n\n"
75 |         #otherwise format as a plain "key: value" line
76 | else:
77 | message += f"- {key}: {value}\n"
78 | return message
79 |
80 | def send_telegram_message(message, chat_id, token):
81 | url = f"https://api.telegram.org/bot{token}/sendMessage?chat_id={chat_id}&text={message}&parse_mode=html&disable_notification=True"
82 | import requests
83 | req = requests.get(url)
84 | if req.status_code != 200:
85 | raise ValueError(f"Telegram request failed with status code {req.status_code}")
86 | def send_media_group(chat_id,telegram_token, images, caption=None, reply_to_message_id=None):
87 | """
88 | Use this method to send an album of photos. On success, an array of Messages that were sent is returned.
89 | :param chat_id: chat id
90 | :param images: list of PIL images to send
91 | :param caption: caption of image
92 | :param reply_to_message_id: If the message is a reply, ID of the original message
93 | :return: response with the sent message
94 | """
95 | SEND_MEDIA_GROUP = f'https://api.telegram.org/bot{telegram_token}/sendMediaGroup'
96 | from io import BytesIO
97 | import requests
98 | files = {}
99 | media = []
100 | for i, img in enumerate(images):
101 | with BytesIO() as output:
102 | img.save(output, format='PNG')
103 | output.seek(0)
104 | name = f'photo{i}'
105 | files[name] = output.read()
106 | # a list of InputMediaPhoto. attach refers to the name of the file in the files dict
107 | media.append(dict(type='photo', media=f'attach://{name}'))
108 | media[0]['caption'] = caption
109 | media[0]['parse_mode'] = 'HTML'
110 | return requests.post(SEND_MEDIA_GROUP, data={'chat_id': chat_id, 'media': json.dumps(media),'disable_notification':True, 'reply_to_message_id': reply_to_message_id }, files=files)
111 | class AverageMeter:
112 | def __init__(self, name=None):
113 | self.name = name
114 | self.reset()
115 |
116 | def reset(self):
117 | self.sum = self.count = self.avg = 0
118 |
119 | def update(self, val, n=1):
120 | self.sum += val * n
121 | self.count += n
122 | self.avg = self.sum / self.count
123 |
124 | def exists(val):
125 | return val is not None
126 |
127 |
128 | def default(val, d):
129 | return val if exists(val) else d
130 |
131 |
132 | def masked_mse_loss(predicted, target, mask, reduction="none"):
133 | masked_predicted = predicted * mask
134 | masked_target = target * mask
135 | return F.mse_loss(masked_predicted, masked_target, reduction=reduction)
136 |
137 | # flash attention forwards and backwards
138 | # https://arxiv.org/abs/2205.14135
139 |
140 |
141 | class FlashAttentionFunction(torch.autograd.function.Function):
142 | @staticmethod
143 | @torch.no_grad()
144 | def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
145 | """ Algorithm 2 in the paper """
146 |
147 | device = q.device
148 | dtype = q.dtype
149 | max_neg_value = -torch.finfo(q.dtype).max
150 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
151 |
152 | o = torch.zeros_like(q)
153 | all_row_sums = torch.zeros(
154 | (*q.shape[:-1], 1), dtype=dtype, device=device)
155 | all_row_maxes = torch.full(
156 | (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
157 |
158 | scale = (q.shape[-1] ** -0.5)
159 |
160 | if not exists(mask):
161 | mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
162 | else:
163 | mask = rearrange(mask, 'b n -> b 1 1 n')
164 | mask = mask.split(q_bucket_size, dim=-1)
165 |
166 | row_splits = zip(
167 | q.split(q_bucket_size, dim=-2),
168 | o.split(q_bucket_size, dim=-2),
169 | mask,
170 | all_row_sums.split(q_bucket_size, dim=-2),
171 | all_row_maxes.split(q_bucket_size, dim=-2),
172 | )
173 |
174 | for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
175 | q_start_index = ind * q_bucket_size - qk_len_diff
176 |
177 | col_splits = zip(
178 | k.split(k_bucket_size, dim=-2),
179 | v.split(k_bucket_size, dim=-2),
180 | )
181 |
182 | for k_ind, (kc, vc) in enumerate(col_splits):
183 | k_start_index = k_ind * k_bucket_size
184 |
185 | attn_weights = einsum(
186 | '... i d, ... j d -> ... i j', qc, kc) * scale
187 |
188 | if exists(row_mask):
189 | attn_weights.masked_fill_(~row_mask, max_neg_value)
190 |
191 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
192 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
193 | device=device).triu(q_start_index - k_start_index + 1)
194 | attn_weights.masked_fill_(causal_mask, max_neg_value)
195 |
196 | block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
197 | attn_weights -= block_row_maxes
198 | exp_weights = torch.exp(attn_weights)
199 |
200 | if exists(row_mask):
201 | exp_weights.masked_fill_(~row_mask, 0.)
202 |
203 | block_row_sums = exp_weights.sum(
204 | dim=-1, keepdims=True).clamp(min=EPSILON)
205 |
206 | new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
207 |
208 | exp_values = einsum(
209 | '... i j, ... j d -> ... i d', exp_weights, vc)
210 |
211 | exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
212 | exp_block_row_max_diff = torch.exp(
213 | block_row_maxes - new_row_maxes)
214 |
215 | new_row_sums = exp_row_max_diff * row_sums + \
216 | exp_block_row_max_diff * block_row_sums
217 |
218 | oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
219 | (exp_block_row_max_diff / new_row_sums) * exp_values)
220 |
221 | row_maxes.copy_(new_row_maxes)
222 | row_sums.copy_(new_row_sums)
223 |
224 | ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
225 | ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
226 |
227 | return o
228 |
229 | @staticmethod
230 | @torch.no_grad()
231 | def backward(ctx, do):
232 | """ Algorithm 4 in the paper """
233 |
234 | causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
235 | q, k, v, o, l, m = ctx.saved_tensors
236 |
237 | device = q.device
238 |
239 | max_neg_value = -torch.finfo(q.dtype).max
240 | qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
241 |
242 | dq = torch.zeros_like(q)
243 | dk = torch.zeros_like(k)
244 | dv = torch.zeros_like(v)
245 |
246 | row_splits = zip(
247 | q.split(q_bucket_size, dim=-2),
248 | o.split(q_bucket_size, dim=-2),
249 | do.split(q_bucket_size, dim=-2),
250 | mask,
251 | l.split(q_bucket_size, dim=-2),
252 | m.split(q_bucket_size, dim=-2),
253 | dq.split(q_bucket_size, dim=-2)
254 | )
255 |
256 | for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
257 | q_start_index = ind * q_bucket_size - qk_len_diff
258 |
259 | col_splits = zip(
260 | k.split(k_bucket_size, dim=-2),
261 | v.split(k_bucket_size, dim=-2),
262 | dk.split(k_bucket_size, dim=-2),
263 | dv.split(k_bucket_size, dim=-2),
264 | )
265 |
266 | for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
267 | k_start_index = k_ind * k_bucket_size
268 |
269 | attn_weights = einsum(
270 | '... i d, ... j d -> ... i j', qc, kc) * scale
271 |
272 | if causal and q_start_index < (k_start_index + k_bucket_size - 1):
273 | causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
274 | device=device).triu(q_start_index - k_start_index + 1)
275 | attn_weights.masked_fill_(causal_mask, max_neg_value)
276 |
277 | exp_attn_weights = torch.exp(attn_weights - mc)
278 |
279 | if exists(row_mask):
280 | exp_attn_weights.masked_fill_(~row_mask, 0.)
281 |
282 | p = exp_attn_weights / lc
283 |
284 | dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
285 | dp = einsum('... i d, ... j d -> ... i j', doc, vc)
286 |
287 | D = (doc * oc).sum(dim=-1, keepdims=True)
288 | ds = p * scale * (dp - D)
289 |
290 | dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
291 | dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
292 |
293 | dqc.add_(dq_chunk)
294 | dkc.add_(dk_chunk)
295 | dvc.add_(dv_chunk)
296 |
297 | return dq, dk, dv, None, None, None, None
298 |
299 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
300 | text_encoder_config = PretrainedConfig.from_pretrained(
301 | pretrained_model_name_or_path,
302 | subfolder="text_encoder",
303 | revision=revision,
304 | )
305 | model_class = text_encoder_config.architectures[0]
306 |
307 | if model_class == "CLIPTextModel":
308 | from transformers import CLIPTextModel
309 |
310 | return CLIPTextModel
311 | elif model_class == "RobertaSeriesModelWithTransformation":
312 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
313 |
314 | return RobertaSeriesModelWithTransformation
315 | else:
316 | raise ValueError(f"{model_class} is not supported.")
317 |
318 | def replace_unet_cross_attn_to_flash_attention():
319 | print("Using FlashAttention")
320 |
321 | def forward_flash_attn(self, x, context=None, mask=None):
322 | q_bucket_size = 512
323 | k_bucket_size = 1024
324 |
325 | h = self.heads
326 | q = self.to_q(x)
327 |
328 | context = context if context is not None else x
329 | context = context.to(x.dtype)
330 |
331 | if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
332 | context_k, context_v = self.hypernetwork.forward(x, context)
333 | context_k = context_k.to(x.dtype)
334 | context_v = context_v.to(x.dtype)
335 | else:
336 | context_k = context
337 | context_v = context
338 |
339 | k = self.to_k(context_k)
340 | v = self.to_v(context_v)
341 | del context, x
342 |
343 | q, k, v = map(lambda t: rearrange(
344 | t, 'b n (h d) -> b h n d', h=h), (q, k, v))
345 |
346 | out = FlashAttentionFunction.apply(q, k, v, mask, False,
347 | q_bucket_size, k_bucket_size)
348 |
349 | out = rearrange(out, 'b h n d -> b n (h d)')
350 |
351 | # diffusers 0.6.0
352 | if type(self.to_out) is torch.nn.Sequential:
353 | return self.to_out(out)
354 |
355 | # diffusers 0.7.0
356 | out = self.to_out[0](out)
357 | out = self.to_out[1](out)
358 | return out
359 |
360 | diffusers.models.attention.CrossAttention.forward = forward_flash_attn
361 | class Depth2Img:
362 | def __init__(self,unet,text_encoder,revision,pretrained_model_name_or_path,accelerator):
363 | self.unet = unet
364 | self.text_encoder = text_encoder
365 | self.revision = revision if revision != 'no' else 'fp32'
366 | self.pretrained_model_name_or_path = pretrained_model_name_or_path
367 | self.accelerator = accelerator
368 | self.pipeline = None
369 | def depth_images(self,paths):
370 | if self.pipeline is None:
371 | self.pipeline = DiffusionPipeline.from_pretrained(
372 | self.pretrained_model_name_or_path,
373 | unet=self.accelerator.unwrap_model(self.unet),
374 | text_encoder=self.accelerator.unwrap_model(self.text_encoder),
375 | revision=self.revision,
376 | local_files_only=True,)
377 | self.pipeline.to(self.accelerator.device)
378 | self.vae_scale_factor = 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
379 | non_depth_image_files = []
380 | image_paths_by_path = {}
381 |
382 | for path in paths:
383 | #if path is list
384 | if isinstance(path, list):
385 | img = Path(path[0])
386 | else:
387 | img = Path(path)
388 | if self.get_depth_image_path(img).exists():
389 | continue
390 | else:
391 | non_depth_image_files.append(img)
392 | image_objects = []
393 | for image_path in non_depth_image_files:
394 | image_instance = Image.open(image_path)
395 | if not image_instance.mode == "RGB":
396 | image_instance = image_instance.convert("RGB")
397 | image_instance = self.pipeline.feature_extractor(
398 | image_instance, return_tensors="pt"
399 | ).pixel_values
400 |
401 | image_instance = image_instance.to(self.accelerator.device)
402 | image_objects.append((image_path, image_instance))
403 |
404 | for image_path, image_instance in image_objects:
405 | path = image_path.parent
406 | ogImg = Image.open(image_path)
407 | ogImg_x = ogImg.size[0]
408 | ogImg_y = ogImg.size[1]
409 | depth_map = self.pipeline.depth_estimator(image_instance).predicted_depth
410 | depth_min = torch.amin(depth_map, dim=[0, 1, 2], keepdim=True)
411 | depth_max = torch.amax(depth_map, dim=[0, 1, 2], keepdim=True)
412 | depth_map = torch.nn.functional.interpolate(depth_map.unsqueeze(1),size=(ogImg_y, ogImg_x),mode="bicubic",align_corners=False,)
413 |
414 | depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
415 | depth_map = depth_map[0,:,:]
416 | depth_map_image = transforms.ToPILImage()(depth_map)
417 | depth_map_image = depth_map_image.filter(ImageFilter.GaussianBlur(radius=1))
418 | depth_map_image.save(self.get_depth_image_path(image_path))
419 |
420 | return 2 ** (len(self.pipeline.vae.config.block_out_channels) - 1)
421 |
422 | def get_depth_image_path(self,image_path):
423 | #if image_path is a string, convert it to a Path object
424 | if isinstance(image_path, str):
425 | image_path = Path(image_path)
426 | return image_path.parent / f"{image_path.stem}-depth.png"
427 |
428 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 and taken from harubaru's implementation https://github.com/harubaru/waifu-diffusion
429 | class EMAModel:
430 | """
431 | Exponential Moving Average of models weights
432 | """
433 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
434 | parameters = list(parameters)
435 | self.shadow_params = [p.clone().detach() for p in parameters]
436 |
437 | self.decay = decay
438 | self.optimization_step = 0
439 |
440 | def get_decay(self, optimization_step):
441 | """
442 |         Compute the per-step update weight (1 - decay) for the moving average, warming up over early steps.
443 | """
444 | value = (1 + optimization_step) / (10 + optimization_step)
445 | return 1 - min(self.decay, value)
446 |
447 | @torch.no_grad()
448 | def step(self, parameters):
449 | parameters = list(parameters)
450 |
451 | self.optimization_step += 1
452 |         one_minus_decay = self.get_decay(self.optimization_step) # keep self.decay fixed; overwriting it made the min() in get_decay drift toward 0.5
453 |
454 | for s_param, param in zip(self.shadow_params, parameters):
455 | if param.requires_grad:
456 |                 tmp = one_minus_decay * (s_param - param)
457 | s_param.sub_(tmp)
458 | else:
459 | s_param.copy_(param)
460 |
461 | torch.cuda.empty_cache()
462 |
463 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
464 | """
465 | Copy current averaged parameters into given collection of parameters.
466 | Args:
467 |             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
468 |                 updated in place with the stored moving averages. They must be
469 |                 passed in the same order as the parameters this `EMAModel`
470 |                 was initialized with.
471 | """
472 | parameters = list(parameters)
473 | for s_param, param in zip(self.shadow_params, parameters):
474 | param.data.copy_(s_param.data)
475 |
476 | def to(self, device=None, dtype=None) -> None:
477 | r"""Move internal buffers of the ExponentialMovingAverage to `device`.
478 | Args:
479 | device: like `device` argument to `torch.Tensor.to`
480 | """
481 | # .to() on the tensors handles None correctly
482 | self.shadow_params = [
483 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
484 | for p in self.shadow_params
485 | ]
--------------------------------------------------------------------------------
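
A sketch (not from the repo) of the `EMAModel` lifecycle during training; the `Linear` stands in for the real U-Net:

```python
import torch
from trainer_util import EMAModel

unet = torch.nn.Linear(4, 4)  # stand-in for the real U-Net
ema = EMAModel(unet.parameters(), decay=0.9999)

optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3)
for _ in range(10):
    loss = unet(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.step(unet.parameters())   # fold the new weights into the shadow copy

ema.copy_to(unet.parameters())    # bake the averaged weights back in before saving
```
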
/scripts/windows_install.py:
--------------------------------------------------------------------------------
1 | import filecmp
2 | import importlib.util
3 | import os
4 | import shutil
5 | import sys
6 | import sysconfig
7 | import subprocess
8 | from pathlib import Path
9 | import requests
10 | import zipfile
11 | if sys.version_info < (3, 8):
12 | import importlib_metadata
13 | else:
14 | import importlib.metadata as importlib_metadata
15 |
16 | req_file = os.path.join(os.getcwd(), "requirements.txt")
17 |
18 | def run(command, desc=None, errdesc=None, custom_env=None):
19 | if desc is not None:
20 | print(desc)
21 |
22 | result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
23 |
24 | if result.returncode != 0:
25 |
26 | message = f"""{errdesc or 'Error running command'}.
27 | Command: {command}
28 | Error code: {result.returncode}
29 | stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
30 | stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''}
31 | """
32 | raise RuntimeError(message)
33 |
34 | return result.stdout.decode(encoding="utf8", errors="ignore")
35 |
36 | def check_versions():
37 | global req_file
38 |     with open(req_file, 'r') as reqs:
39 |         lines = reqs.readlines()
40 | reqs_dict = {}
41 | for line in lines:
42 | splits = line.split("==")
43 | if len(splits) == 2:
44 | key = splits[0]
45 | if "torch" not in key:
46 | if "diffusers" in key:
47 | key = "diffusers"
48 | reqs_dict[key] = splits[1].replace("\n", "").strip()
49 |
50 | if os.name == "nt":
51 | reqs_dict["torch"] = "1.12.1+cu116"
52 | reqs_dict["torchvision"] = "0.13.1+cu116"
53 |
54 | checks = ["xformers","bitsandbytes", "diffusers", "transformers", "torch", "torchvision"]
55 | for check in checks:
56 | check_ver = "N/A"
57 | status = "[ ]"
58 | try:
59 | check_available = importlib.util.find_spec(check) is not None
60 | if check_available:
61 | check_ver = importlib_metadata.version(check)
62 | if check in reqs_dict:
63 | req_version = reqs_dict[check]
64 | if str(check_ver) == str(req_version):
65 | status = "[+]"
66 | else:
67 | status = "[!]"
68 | except importlib_metadata.PackageNotFoundError:
69 | check_available = False
70 | if not check_available:
71 | status = "[!]"
72 | print(f"{status} {check} NOT installed.")
73 | if check == 'xformers':
74 | #if windows, install xformers from prebuilt wheel
75 | x_cmd = r"https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/xformers-0.0.17.dev464-cp310-cp310-win_amd64.whl"
76 | print(f"Installing xformers with: pip install {x_cmd}")
77 | run(f"pip install {x_cmd}", desc="Installing xformers")
78 |
79 | else:
80 | print(f"{status} {check} version {check_ver} installed.")
81 |
82 |
83 | dreambooth_skip_install = os.environ.get('DREAMBOOTH_SKIP_INSTALL', False)
84 |
85 | if not dreambooth_skip_install:
86 | check_versions()
87 | name = "StableTuner"
88 | run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
89 | f"Couldn't install {name} requirements.")
90 |
91 | # On Windows, upgrade torch/torchvision to a CUDA build (the command below targets cu117) to match the prebuilt B&B binaries.
92 | if os.name == "nt":
93 | torch_cmd = os.environ.get('TORCH_COMMAND', None)
94 | if torch_cmd is None:
95 |             torch_cmd = 'pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade'
96 |
97 | run(f'"{sys.executable}" -m {torch_cmd}', "Checking/upgrading existing torch/torchvision installation", "Couldn't install torch")
98 |
99 |
100 | #if .cache directory in Path.home() exists
101 | hf_cache_dir = Path.home() / ".cache"
102 | if hf_cache_dir.exists():
103 | #check if huggingface exists
104 | hf_dir = hf_cache_dir / "huggingface"
105 | if hf_dir.exists():
106 | #check if accelerate exists
107 | accelerate_dir = hf_dir / "accelerate"
108 | if accelerate_dir.exists():
109 |
110 | src_file = 'resources/accelerate_windows/accelerate_default_config.yaml'
111 | dst_file = 'default_config.yaml'
112 | #load from cwd
113 | src = Path.cwd() / src_file
114 | dst = accelerate_dir / dst_file
115 | print(src)
116 | if src.exists():
117 | shutil.copy2(src, dst)
118 | print(f"Updated {dst_file} in {accelerate_dir}")
119 | else:
120 | #make dirs
121 | hf_cache_dir.mkdir(parents=True, exist_ok=True)
122 | hf_dir = hf_cache_dir / "huggingface"
123 | hf_dir.mkdir(parents=True, exist_ok=True)
124 | accelerate_dir = hf_dir / "accelerate"
125 | accelerate_dir.mkdir(parents=True, exist_ok=True)
126 | src_file = 'accelerate_default_config.json'
127 | dst_file = 'default_config.json'
128 | src = Path.cwd() / src_file
129 | dst = accelerate_dir / dst_file
130 | if src.exists():
131 |                 if not dst.exists():
132 | shutil.copy2(src, dst)
133 | print(f"Created {dst_file} in {accelerate_dir}")
134 |
135 |
136 |
137 | base_dir = os.path.dirname(os.getcwd())
138 | #repo = git.Repo(base_dir)
139 | #revision = repo.rev_parse("HEAD")
140 | #print(f"Dreambooth revision is {revision}")
141 | check_versions()
142 | # Check for "different" B&B Files and copy only if necessary
143 | if os.name == "nt":
144 | python = sys.executable
145 | run(f'"{python}" -m pip install https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl', "Installing Triton", "Couldn't install triton")
146 | bnb_src = os.path.join(os.getcwd(), "resources/bitsandbytes_windows")
147 | bnb_dest = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")
148 | cudnn_src = os.path.join(os.getcwd(), "resources/cudnn_windows")
149 |     #check if cudnn is in cwd
150 | if not os.path.exists(cudnn_src):
151 | print("Can't find CUDNN in resources, trying main folder...")
152 | cudnn_src = os.path.join(os.getcwd(), "cudnn_windows")
153 | if not os.path.exists(cudnn_src):
154 | cudnn_url = "https://b1.thefileditch.ch/mwxKTEtelILoIbMbruuM.zip"
155 |             print("Downloading CUDNN 8.6")
156 | #download with requests
157 | r = requests.get(cudnn_url, allow_redirects=True)
158 | #save to cwd
159 | open('cudnn_windows.zip', 'wb').write(r.content)
160 | #unzip
161 | with zipfile.ZipFile('cudnn_windows.zip','r') as zip_ref:
162 | zip_ref.extractall(os.path.join(os.getcwd(),"resources/cudnn_windows"))
163 | #remove zip
164 | os.remove('cudnn_windows.zip')
165 | cudnn_src = os.path.join(os.getcwd(), "resources/cudnn_windows")
166 |
167 | cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib")
168 | print(f"Checking for B&B files in {bnb_dest}")
169 | if not os.path.exists(bnb_dest):
170 | # make destination directory
171 | os.makedirs(bnb_dest, exist_ok=True)
172 | printed = False
173 | filecmp.clear_cache()
174 | for file in os.listdir(bnb_src):
175 | src_file = os.path.join(bnb_src, file)
176 | if file == "main.py":
177 | dest = os.path.join(bnb_dest, "cuda_setup")
178 | if not os.path.exists(dest):
179 | os.mkdir(dest)
180 | else:
181 | dest = bnb_dest
182 | if not os.path.exists(dest):
183 | os.mkdir(dest)
184 | dest_file = os.path.join(dest, file)
185 | status = shutil.copy2(src_file, dest)
186 |         if status and not printed:
187 |             print("Copied B&B files to destination"); printed = True
188 | print(f"Checking for CUDNN files in {cudnn_dest}")
189 | if os.path.exists(cudnn_src):
190 | if os.path.exists(cudnn_dest):
191 | # check for different files
192 | filecmp.clear_cache()
193 | for file in os.listdir(cudnn_src):
194 | src_file = os.path.join(cudnn_src, file)
195 | dest_file = os.path.join(cudnn_dest, file)
196 |             #copy when the destination file is missing or differs
197 |             if not os.path.exists(dest_file) or not filecmp.cmp(src_file, dest_file, shallow=False):
198 |                 status = shutil.copy2(src_file, cudnn_dest)
199 |                 if status:
200 |                     print("Copied CUDNN 8.6 files to destination")
201 | d_commit = '8178c84'
202 | diffusers_cmd = f"git+https://github.com/huggingface/diffusers.git@{d_commit}#egg=diffusers --force-reinstall"
203 | run(f'"{python}" -m pip install {diffusers_cmd}', f"Installing Diffusers {d_commit} commit", "Couldn't install diffusers")
204 | #install requirements file
205 | t_commit = 'cc84075'
206 | trasn_cmd = f"git+https://github.com/huggingface/transformers.git@{t_commit}#egg=transformers --force-reinstall"
207 | run(f'"{python}" -m pip install {trasn_cmd}', f"Installing Transformers {t_commit} commit", "Couldn't install transformers")
208 |
209 | req_file = os.path.join(os.getcwd(), "requirements.txt")
210 | run(f'"{python}" -m pip install -r "{req_file}"', "Updating requirements", "Couldn't install requirements")
211 |
--------------------------------------------------------------------------------
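
The installer's "copy only if necessary" passes reduce to this pattern; a standalone sketch (the helper name and paths are illustrative, not from the repo):

```python
import filecmp
import os
import shutil

def sync_dir(src_dir, dst_dir):
    """Copy files from src_dir to dst_dir only when missing or different."""
    os.makedirs(dst_dir, exist_ok=True)
    filecmp.clear_cache()
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if not os.path.exists(dst) or not filecmp.cmp(src, dst, shallow=False):
            shutil.copy2(src, dst)
            print(f"Copied {name} -> {dst_dir}")

# e.g. sync_dir("resources/cudnn_windows", os.path.join(purelib, "torch", "lib"))
```
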