├── sd_batch_runner ├── __init__.py ├── sdwebui_temp_fix.py ├── tagger.py ├── lora.py ├── util.py └── generate.py ├── launch_GUI.bat ├── launch_cmd.bat ├── img ├── generate.png ├── tagging.png └── gui_sample.png ├── requirements.txt ├── .gitignore ├── default_config ├── default_lora_dir.json ├── default_preset_tags.json ├── default_config.json └── default_controlnet.json ├── main.py └── README.md /sd_batch_runner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /launch_GUI.bat: -------------------------------------------------------------------------------- 1 | call "venv/Scripts/activate.bat" 2 | flet gui_main.py 3 | -------------------------------------------------------------------------------- /launch_cmd.bat: -------------------------------------------------------------------------------- 1 | %windir%\System32\cmd.exe /K "venv\Scripts\activate.bat" 2 | 3 | -------------------------------------------------------------------------------- /img/generate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/sd_batch_runner/main/img/generate.png -------------------------------------------------------------------------------- /img/tagging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/sd_batch_runner/main/img/tagging.png -------------------------------------------------------------------------------- /img/gui_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s9roll7/sd_batch_runner/main/img/gui_sample.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | fire 
4 | opencv-python 5 | webuiapi==0.9.15 6 | av 7 | flet==0.25.1 8 | pandas 9 | huggingface_hub 10 | tqdm 11 | timm 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /venv 3 | /tmp 4 | /input 5 | /output 6 | .vscode 7 | /lora_dir_env 8 | /config.json 9 | /controlnet.json 10 | /lora_dir.json 11 | /preset_tags.json 12 | -------------------------------------------------------------------------------- /default_config/default_lora_dir.json: -------------------------------------------------------------------------------- 1 | { 2 | "pony": { 3 | "character_dir_path": "enter your character lora directory. (ex. YOUR_SD_ENV_PATH/models/Lora/pony/character", 4 | "style_dir_path": "YOUR_SD_ENV_PATH/models/Lora/pony/style", 5 | "pose_dir_path": "YOUR_SD_ENV_PATH/models/Lora/pony/pose", 6 | "item_dir_path": "YOUR_SD_ENV_PATH/models/Lora/pony/item" 7 | }, 8 | "sd15": { 9 | "character_dir_path": "enter your character lora directory. (ex. 
YOUR_SD_ENV_PATH/models/Lora/sd15/character", 10 | "style_dir_path": "YOUR_SD_ENV_PATH/models/Lora/sd15/style", 11 | "pose_dir_path": "YOUR_SD_ENV_PATH/models/Lora/sd15/pose", 12 | "item_dir_path": "YOUR_SD_ENV_PATH/models/Lora/sd15/item" 13 | } 14 | } -------------------------------------------------------------------------------- /default_config/default_preset_tags.json: -------------------------------------------------------------------------------- 1 | { 2 | "pony_quality": { 3 | "prompt": "score_9, score_8_up, score_8, score_7_up", 4 | "negative_prompt": "score_6, score_5, score_4, (watermark, deformed, blurry, censored, low quality, text, worst quality, extra hands)", 5 | "is_footer": false 6 | }, 7 | "booru_quality": { 8 | "prompt": "masterpiece, best quality, ultra-detailed, very aesthetic, illustration, disheveled hair, perfect composition, moist skin, intricate details", 9 | "negative_prompt": "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, very displeasing", 10 | "is_footer": false 11 | }, 12 | "perfect_face": { 13 | "prompt": "perfect face, perfect eyes", 14 | "negative_prompt": "", 15 | "is_footer": false 16 | }, 17 | "anime_style": { 18 | "prompt": "source_anime", 19 | "negative_prompt": "3D, Photorealistic, Lifelike, Realistic, True to Life, Vivid, Picture-Perfect, Naturalistic, Convincing", 20 | "is_footer": false 21 | }, 22 | "photo_style": { 23 | "prompt": "source_real, raw, photo, very detailed, realistic, highly detailed, high detail, soft lighting, dramatic shadows, highly detailed, ((detailed skin)) ,depth of field, ((film grain))", 24 | "negative_prompt": "", 25 | "is_footer": false 26 | }, 27 | "detail_up": { 28 | "prompt": "", 29 | "negative_prompt": "", 30 | "is_footer": true 31 | } 32 | } -------------------------------------------------------------------------------- /sd_batch_runner/sdwebui_temp_fix.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from pathlib import Path 5 | from datetime import datetime 6 | import math 7 | import random 8 | import re 9 | 10 | import webuiapi 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 15 | handler = logging.StreamHandler() 16 | handler.setFormatter(formatter) 17 | logger.addHandler(handler) 18 | 19 | ######################################################### 20 | # sdwebui temp fix 21 | 22 | class ControlNetUnit2(webuiapi.ControlNetUnit): 23 | 24 | def to_dict(self): 25 | if not hasattr(self, 'effective_region_mask'): 26 | self.effective_region_mask = None 27 | 28 | if self.image is None and self.mask is None: 29 | return { 30 | "module": self.module, 31 | "model": None if self.model=="none" else self.model, 32 | "weight": self.weight, 33 | "resize_mode": self.resize_mode, 34 | "low_vram": self.low_vram, 35 | "processor_res": self.processor_res, 36 | "threshold_a": self.threshold_a, 37 | "threshold_b": self.threshold_b, 38 | "guidance_start": self.guidance_start, 39 | "guidance_end": self.guidance_end, 40 | "control_mode": self.control_mode, 41 | "pixel_perfect": self.pixel_perfect, 42 | "hr_option": self.hr_option, 43 | "enabled": self.enabled, 44 | } 45 | else: 46 | return { 47 | "image": webuiapi.raw_b64_img(self.image) if self.image else "", 48 | "mask": webuiapi.raw_b64_img(self.mask) if self.mask is not None else None, 49 | "effective_region_mask": webuiapi.raw_b64_img(self.effective_region_mask) if self.effective_region_mask is not None else None, 50 | "module": self.module, 51 | "model": None if self.model=="none" else self.model, 52 | "weight": self.weight, 53 | "resize_mode": self.resize_mode, 54 | "low_vram": self.low_vram, 55 | "processor_res": self.processor_res, 56 | 
"threshold_a": self.threshold_a, 57 | "threshold_b": self.threshold_b, 58 | "guidance_start": self.guidance_start, 59 | "guidance_end": self.guidance_end, 60 | "control_mode": self.control_mode, 61 | "pixel_perfect": self.pixel_perfect, 62 | "hr_option": self.hr_option, 63 | "enabled": self.enabled, 64 | } 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from pathlib import Path 4 | import fire 5 | 6 | from sd_batch_runner.util import * 7 | from sd_batch_runner.lora import update_lora_command,show_lora_command,show_lora_env_command,set_lora_env_command 8 | from sd_batch_runner.generate import one_command,generate_command,show_checkpoint_command,set_default_checkpoint_command,show_controlnet_command 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 13 | handler = logging.StreamHandler() 14 | handler.setFormatter(formatter) 15 | logger.addHandler(handler) 16 | 17 | 18 | 19 | ############################################################### 20 | 21 | class Command: 22 | def __init__(self): 23 | config_restore_files_if_needed() 24 | 25 | def update_lora(self, is_overwrite=False): 26 | start_tim = time.time() 27 | 28 | update_lora_command(is_overwrite) 29 | 30 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 31 | 32 | def one(self,char="@random",style="@random",pose="@random",item=None,header="score_9, score_8_up, score_7_up, score_6_up, score_5_up, score_4_up, masterpiece, perfect face, perfect eyes",footer="zPDXL3",n=1): 33 | start_tim = time.time() 34 | 35 | one_command(char,style,pose,item,header,footer,n) 36 | 37 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 38 | 39 | def generate(self, json_path, n=1): 40 | start_tim = 
time.time() 41 | 42 | generate_command(Path(json_path),n) 43 | 44 | clear_video_cache() 45 | 46 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 47 | 48 | def show_checkpoint(self): 49 | start_tim = time.time() 50 | 51 | show_checkpoint_command() 52 | 53 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 54 | 55 | def set_default_checkpoint(self, checkpoint_number): 56 | start_tim = time.time() 57 | 58 | set_default_checkpoint_command(checkpoint_number) 59 | 60 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 61 | 62 | def show_controlnet(self): 63 | start_tim = time.time() 64 | 65 | show_controlnet_command() 66 | 67 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 68 | 69 | def show_lora(self): 70 | start_tim = time.time() 71 | 72 | show_lora_command() 73 | 74 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 75 | 76 | def show_lora_env(self): 77 | start_tim = time.time() 78 | 79 | show_lora_env_command() 80 | 81 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 82 | 83 | def set_lora_env(self, new_env): 84 | start_tim = time.time() 85 | 86 | set_lora_env_command(new_env) 87 | 88 | logger.info(f"Total Elapsed time : {time.time() - start_tim}") 89 | 90 | fire.Fire(Command) 91 | -------------------------------------------------------------------------------- /default_config/default_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "lora_generate_tag": { 3 | "enable_character": true, 4 | "enable_style": true, 5 | "enable_pose": true, 6 | "enable_item": true, 7 | "tag_th_character": 0.5, 8 | "tag_th_style": 0.5, 9 | "tag_th_pose": 0.5, 10 | "tag_th_item": 0.5, 11 | "prohibited_tags_character": [ 12 | "solo", 13 | "simple background" 14 | ], 15 | "prohibited_tags_style": [ 16 | "solo", 17 | "simple background" 18 | ], 19 | "prohibited_tags_pose": [ 20 | "simple background" 21 | ], 22 | "prohibited_tags_item": [ 23 | "solo", 
24 | "simple background" 25 | ] 26 | }, 27 | "generation_setting_txt2img": { 28 | "enable_hr": false, 29 | "denoising_strength": 0.7, 30 | "firstphase_width": 0, 31 | "firstphase_height": 0, 32 | "hr_scale": 2, 33 | "hr_upscaler": "Latent", 34 | "hr_second_pass_steps": 0, 35 | "hr_resize_x": 0, 36 | "hr_resize_y": 0 37 | }, 38 | "generation_setting_img2img": { 39 | "resize_mode": 0, 40 | "denoising_strength": 0.75, 41 | "image_cfg_scale": 1.5, 42 | "mask_blur": 4, 43 | "inpainting_fill": 1, 44 | "inpaint_full_res": true, 45 | "inpaint_full_res_padding": 0, 46 | "inpainting_mask_invert": 0, 47 | "initial_noise_multiplier": 1 48 | }, 49 | "generation_setting_common": { 50 | "sampler_name": "Euler a", 51 | "scheduler": "automatic", 52 | "steps": 25, 53 | "cfg_scale": 7.0, 54 | "width": 832, 55 | "height": 1216, 56 | "restore_faces": false, 57 | "tiling": false, 58 | "do_not_save_samples": false, 59 | "do_not_save_grid": false, 60 | "negative_prompt": "", 61 | "eta": 1.0, 62 | "send_images": true, 63 | "save_images": false 64 | }, 65 | "prompt_gen_setting": { 66 | "model_name": "AUTOMATIC/promptgen-majinai-safe", 67 | "text": "", 68 | "min_length": 20, 69 | "max_length": 20, 70 | "num_beams": 1, 71 | "temperature": 1, 72 | "repetition_penalty": 1, 73 | "length_preference": 1, 74 | "sampling_mode": "Top K", 75 | "top_k": 12, 76 | "top_p": 0.15 77 | }, 78 | "overwrite_generation_setting": { 79 | "overwrite_steps": true, 80 | "overwrite_sampler_name": true, 81 | "overwrite_scheduler": true, 82 | "overwrite_cfg_scale": true, 83 | "overwrite_width": true, 84 | "overwrite_height": true, 85 | "overwrite_prompt": true, 86 | "overwrite_negative_prompt": true, 87 | "overwrite_seed": true, 88 | "add_lora": false, 89 | "add_prompt_gen": false 90 | }, 91 | "segment_anything": { 92 | "sam_model_name": "sam_hq_vit_h.pth", 93 | "dino_model_name": "GroundingDINO_SwinT_OGC (694MB)" 94 | }, 95 | "default_checkpoint": "ponyDiffusionV6XL_v6StartWithThisOne.safetensors [67ab2fd8ec]", 96 | 
"lora_dir_env": "pony", 97 | "lora_block_weight": { 98 | "character": { 99 | "enable_lbw": true, 100 | "preset": "Char", 101 | "start_stop_step": "stop", 102 | "start_stop_step_value": 10 103 | }, 104 | "character2": { 105 | "enable_lbw": true, 106 | "preset": "Char", 107 | "start_stop_step": "stop", 108 | "start_stop_step_value": 10 109 | }, 110 | "style": { 111 | "enable_lbw": false, 112 | "preset": "ArtStyle", 113 | "start_stop_step": "none", 114 | "start_stop_step_value": 10 115 | }, 116 | "style2": { 117 | "enable_lbw": false, 118 | "preset": "ArtStyle", 119 | "start_stop_step": "none", 120 | "start_stop_step_value": 10 121 | }, 122 | "pose": { 123 | "enable_lbw": true, 124 | "preset": "Pose", 125 | "start_stop_step": "stop", 126 | "start_stop_step_value": 10 127 | }, 128 | "pose2": { 129 | "enable_lbw": true, 130 | "preset": "Pose", 131 | "start_stop_step": "stop", 132 | "start_stop_step_value": 10 133 | }, 134 | "item": { 135 | "enable_lbw": true, 136 | "preset": "Pose", 137 | "start_stop_step": "stop", 138 | "start_stop_step_value": 10 139 | }, 140 | "item2": { 141 | "enable_lbw": true, 142 | "preset": "Pose", 143 | "start_stop_step": "stop", 144 | "start_stop_step_value": 10 145 | } 146 | }, 147 | "adetailer": [ 148 | { 149 | "ad_model": "face_yolov8n.pt", 150 | "ad_prompt": "", 151 | "ad_negative_prompt": "" 152 | } 153 | ] 154 | } -------------------------------------------------------------------------------- /default_config/default_controlnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "open_pose": { 3 | "module" : "openpose_full", 4 | "model" : "controlnet-openpose-sdxl [d0333a45]", 5 | "weight" : 0.5, 6 | "resize_mode" : "Resize and Fill", 7 | "low_vram" : false, 8 | "processor_res" : 512, 9 | "threshold_a" : 64, 10 | "threshold_b" : 64, 11 | "guidance_start" : 0.0, 12 | "guidance_end" : 0.5, 13 | "control_mode" : 0, 14 | "pixel_perfect" : true, 15 | "hr_option" : "Both" 16 | }, 17 | "open_pose_n": 
{ 18 | "module" : "none", 19 | "model" : "controlnet-openpose-sdxl [d0333a45]", 20 | "weight" : 0.5, 21 | "resize_mode" : "Resize and Fill", 22 | "low_vram" : false, 23 | "processor_res" : 512, 24 | "threshold_a" : 64, 25 | "threshold_b" : 64, 26 | "guidance_start" : 0.0, 27 | "guidance_end" : 0.5, 28 | "control_mode" : 0, 29 | "pixel_perfect" : true, 30 | "hr_option" : "Both" 31 | }, 32 | "depth": { 33 | "module" : "depth_anything_v2", 34 | "model" : "controlnet-depth-sdxl [e590f04c]", 35 | "weight" : 0.25, 36 | "resize_mode" : "Resize and Fill", 37 | "low_vram" : false, 38 | "processor_res" : 512, 39 | "threshold_a" : 64, 40 | "threshold_b" : 64, 41 | "guidance_start" : 0.0, 42 | "guidance_end" : 0.5, 43 | "control_mode" : 0, 44 | "pixel_perfect" : true, 45 | "hr_option" : "Both" 46 | }, 47 | "depth_n": { 48 | "module" : "none", 49 | "model" : "controlnet-depth-sdxl [e590f04c]", 50 | "weight" : 0.25, 51 | "resize_mode" : "Resize and Fill", 52 | "low_vram" : false, 53 | "processor_res" : 512, 54 | "threshold_a" : 64, 55 | "threshold_b" : 64, 56 | "guidance_start" : 0.0, 57 | "guidance_end" : 0.5, 58 | "control_mode" : 0, 59 | "pixel_perfect" : true, 60 | "hr_option" : "Both" 61 | }, 62 | "inpaint": { 63 | "module" : "inpaint_global_harmonious", 64 | "model" : "controlnet++_union_sdxl [15e6ad5d]", 65 | "weight" : 1.0, 66 | "resize_mode" : "Resize and Fill", 67 | "low_vram" : false, 68 | "processor_res" : 1.0, 69 | "threshold_a" : 0.5, 70 | "threshold_b" : 0.5, 71 | "guidance_start" : 0.0, 72 | "guidance_end" : 1.0, 73 | "control_mode" : 0, 74 | "pixel_perfect" : true, 75 | "hr_option" : "Both" 76 | }, 77 | "ref_only": { 78 | "module" : "reference_only", 79 | "model" : "none", 80 | "weight" : 1.0, 81 | "resize_mode" : "Resize and Fill", 82 | "low_vram" : false, 83 | "processor_res" : 512, 84 | "threshold_a" : 64, 85 | "threshold_b" : 64, 86 | "guidance_start" : 0.0, 87 | "guidance_end" : 1.0, 88 | "control_mode" : 0, 89 | "pixel_perfect" : true, 90 | "hr_option" : 
"Both" 91 | }, 92 | "ip_adapter": { 93 | "module" : "ip-adapter_clip_sdxl_plus_vith", 94 | "model" : "ipAdapterModelsForSDXL_ipAdapterSDXLVitH [75a08f84]", 95 | "weight" : 1.0, 96 | "resize_mode" : "Resize and Fill", 97 | "low_vram" : false, 98 | "processor_res" : 512, 99 | "threshold_a" : 64, 100 | "threshold_b" : 64, 101 | "guidance_start" : 0.0, 102 | "guidance_end" : 1.0, 103 | "control_mode" : 0, 104 | "pixel_perfect" : true, 105 | "hr_option" : "Both" 106 | }, 107 | "tile_real": { 108 | "module" : "tile_resample", 109 | "model" : "TTPLanet_SDXL_Controlnet_Tile_Realistic_v20Fp16 [c32b8550]", 110 | "weight" : 1.0, 111 | "resize_mode" : "Resize and Fill", 112 | "low_vram" : false, 113 | "processor_res" : 512, 114 | "threshold_a" : 64, 115 | "threshold_b" : 64, 116 | "guidance_start" : 0.0, 117 | "guidance_end" : 1.0, 118 | "control_mode" : 1, 119 | "pixel_perfect" : true, 120 | "hr_option" : "Both" 121 | }, 122 | "tile_real_blur": { 123 | "module" : "blur_gaussian", 124 | "model" : "TTPLanet_SDXL_Controlnet_Tile_Realistic_v20Fp16 [c32b8550]", 125 | "weight" : 1.0, 126 | "resize_mode" : "Resize and Fill", 127 | "low_vram" : false, 128 | "processor_res" : 512, 129 | "threshold_a" : 64, 130 | "threshold_b" : 64, 131 | "guidance_start" : 0.0, 132 | "guidance_end" : 1.0, 133 | "control_mode" : 1, 134 | "pixel_perfect" : true, 135 | "hr_option" : "Both" 136 | }, 137 | "tile": { 138 | "module" : "tile_resample", 139 | "model" : "controlnet-tile-sdxl [4d6257d3]", 140 | "weight" : 1.0, 141 | "resize_mode" : "Resize and Fill", 142 | "low_vram" : false, 143 | "processor_res" : 512, 144 | "threshold_a" : 64, 145 | "threshold_b" : 64, 146 | "guidance_start" : 0.0, 147 | "guidance_end" : 1.0, 148 | "control_mode" : 1, 149 | "pixel_perfect" : true, 150 | "hr_option" : "Both" 151 | }, 152 | "tile_blur": { 153 | "module" : "blur_gaussian", 154 | "model" : "controlnet-tile-sdxl [4d6257d3]", 155 | "weight" : 1.0, 156 | "resize_mode" : "Resize and Fill", 157 | "low_vram" : false, 
158 | "processor_res" : 512, 159 | "threshold_a" : 64, 160 | "threshold_b" : 64, 161 | "guidance_start" : 0.0, 162 | "guidance_end" : 1.0, 163 | "control_mode" : 1, 164 | "pixel_perfect" : true, 165 | "hr_option" : "Both" 166 | } 167 | } -------------------------------------------------------------------------------- /sd_batch_runner/tagger.py: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py 2 | 3 | import re 4 | import logging 5 | import os 6 | 7 | import cv2 8 | import numpy as np 9 | from pathlib import Path 10 | import torch 11 | import timm 12 | import pandas as pd 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | from sd_batch_runner.util import get_image_file_list 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.INFO) 21 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 22 | handler = logging.StreamHandler() 23 | handler.setFormatter(formatter) 24 | logger.addHandler(handler) 25 | 26 | 27 | 28 | def prepare_wd14tagger(): 29 | import os 30 | from pathlib import PurePosixPath 31 | 32 | from huggingface_hub import hf_hub_download 33 | 34 | os.makedirs("data/models/WD14tagger", exist_ok=True) 35 | for hub_file in [ 36 | "selected_tags.csv", 37 | ]: 38 | path = Path(hub_file) 39 | 40 | saved_path = "data/models/WD14tagger" / path 41 | 42 | if os.path.exists(saved_path): 43 | continue 44 | 45 | hf_hub_download( 46 | repo_id="SmilingWolf/wd-eva02-large-tagger-v3", subfolder=PurePosixPath(path.parent), filename=PurePosixPath(path.name), local_dir="data/models/WD14tagger" 47 | ) 48 | 49 | def make_square(img, target_size=None): 50 | old_size = img.shape[:2] 51 | desired_size = max(old_size) 52 | if target_size: 53 | desired_size = max(desired_size, target_size) 54 | 55 | delta_w = desired_size - old_size[1] 56 | delta_h = desired_size - old_size[0] 57 | 
top, bottom = delta_h // 2, delta_h - (delta_h // 2) 58 | left, right = delta_w // 2, delta_w - (delta_w // 2) 59 | 60 | color = [255, 255, 255] 61 | new_im = cv2.copyMakeBorder( 62 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color 63 | ) 64 | return new_im 65 | 66 | def smart_resize(img, size): 67 | # Assumes the image has already gone through make_square 68 | if img.shape[0] > size: 69 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) 70 | elif img.shape[0] < size: 71 | img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) 72 | return img 73 | 74 | 75 | class Tagger: 76 | def __init__(self, general_threshold, character_threshold, with_confidence, is_danbooru_format): 77 | prepare_wd14tagger() 78 | 79 | self.model = timm.create_model('hf-hub:SmilingWolf/wd-eva02-large-tagger-v3', pretrained=False).eval() 80 | state_dict = timm.models.load_state_dict_from_hf("SmilingWolf/wd-eva02-large-tagger-v3") 81 | self.model.load_state_dict(state_dict) 82 | 83 | data_config = timm.data.resolve_data_config(self.model.pretrained_cfg, model=self.model) 84 | self.transform = timm.data.create_transform(**data_config, is_training=False) 85 | 86 | df = pd.read_csv("data/models/WD14tagger/selected_tags.csv") 87 | self.tag_names = df["name"].tolist() 88 | self.rating_indexes = list(np.where(df["category"] == 9)[0]) 89 | self.general_indexes = list(np.where(df["category"] == 0)[0]) 90 | self.character_indexes = list(np.where(df["category"] == 4)[0]) 91 | 92 | self.general_threshold = general_threshold 93 | self.character_threshold = character_threshold 94 | self.with_confidence = with_confidence 95 | self.is_danbooru_format = is_danbooru_format 96 | 97 | def __call__( 98 | self, 99 | image: Image, 100 | ): 101 | 102 | # Alpha to white 103 | image = image.convert("RGBA") 104 | new_image = Image.new("RGBA", image.size, "WHITE") 105 | new_image.paste(image, mask=image) 106 | image = new_image.convert("RGB") 107 | image = np.asarray(image) 108 
| 109 | # PIL RGB to OpenCV BGR 110 | image = image[:, :, ::-1] 111 | image = make_square(image) 112 | 113 | image = self.transform( Image.fromarray(image) ).unsqueeze(0) 114 | 115 | self.model = self.model.to("cuda") 116 | image = image.to("cuda") 117 | 118 | probs = self.model.forward(image) 119 | probs = torch.nn.functional.sigmoid(probs) 120 | 121 | image = image.to("cpu") 122 | probs = probs.to("cpu") 123 | 124 | labels = list(zip(self.tag_names, probs.squeeze(0).numpy())) 125 | 126 | # First 4 labels are actually ratings: pick one with argmax 127 | ratings_names = [labels[i] for i in self.rating_indexes] 128 | rating = dict(ratings_names) 129 | 130 | # Then we have general tags: pick any where prediction confidence > threshold 131 | general_names = [labels[i] for i in self.general_indexes] 132 | general_res = [x for x in general_names if x[1] > self.general_threshold] 133 | general_res = dict(general_res) 134 | 135 | # Everything else is characters: pick any where prediction confidence > threshold 136 | character_names = [labels[i] for i in self.character_indexes] 137 | character_res = [x for x in character_names if x[1] > self.character_threshold] 138 | character_res = dict(character_res) 139 | 140 | #logger.info(f"{rating=}") 141 | #logger.info(f"{general_res=}") 142 | #logger.info(f"{character_res=}") 143 | 144 | #general_res = {k:general_res[k] for k in (general_res.keys() - set(self.ignore_tokens)) } 145 | #character_res = {k:character_res[k] for k in (character_res.keys() - set(self.ignore_tokens)) } 146 | 147 | prompt = "" 148 | 149 | if self.with_confidence: 150 | prompt = [ f"({i}:{character_res[i]:.2f})" for i in (character_res.keys()) ] 151 | prompt += [ f"({i}:{general_res[i]:.2f})" for i in (general_res.keys()) ] 152 | else: 153 | prompt = [ i for i in (character_res.keys()) ] 154 | prompt += [ i for i in (general_res.keys()) ] 155 | 156 | prompt = ",".join(prompt) 157 | 158 | if not self.is_danbooru_format: 159 | prompt = prompt.replace("_", " 
") 160 | 161 | #logger.info(f"{prompt=}") 162 | return prompt 163 | 164 | def __del__(self): 165 | if self.model: 166 | self.model = self.model.to("cpu") 167 | 168 | 169 | def get_labels(frame_dir, general_threshold, character_threshold, with_confidence, is_danbooru_format): 170 | 171 | result = {} 172 | if os.path.isdir(frame_dir): 173 | 174 | png_list = get_image_file_list(frame_dir) 175 | 176 | with torch.no_grad(): 177 | tagger = Tagger(general_threshold, character_threshold, with_confidence, is_danbooru_format) 178 | 179 | for p in tqdm( png_list, desc=f"WD14tagger"): 180 | result[p] = tagger( 181 | image= Image.open(p) 182 | ) 183 | 184 | tagger = None 185 | 186 | torch.cuda.empty_cache() 187 | 188 | return result 189 | 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SD Batch Runner 2 | 3 | This is an application for batch execution of [Stable Diffusion Web UI(A1111)](https://github.com/AUTOMATIC1111/stable-diffusion-webui) launched on a local PC via api. 4 | (This is NOT sdwebui extension.) 5 | 6 | It is easy to do the following. 7 | - Image generation by simply selecting lora without manually setting trigger word or lbw 8 | - Batch generation of images with only the character or style changed with specific settings 9 | - Easy to try multiple lora combinations. (ex. Generate a combination of a specific style and all character loras) 10 | - Complex image generation procedures, such as those involving controlnets or consecutive txt2img,img2img, can be saved and reused at a later date. 11 | (ex. After generating txt2img once, use it as an input image and execute img2img 3 times with slightly different conditions, and execute this procedure 20 times, changing the character and style each time. 
) 12 | 13 | 14 | 15 | 16 | ## Installation (for Windows) 17 | [Python 3.10](https://www.python.org/) and a git client must be installed. 18 | 19 | ```sh 20 | git clone https://github.com/s9roll7/sd_batch_runner.git 21 | cd sd_batch_runner 22 | py -3.10 -m venv venv 23 | venv\Scripts\activate.bat 24 | # Please install torch according to your environment.(https://pytorch.org/get-started/locally/) 25 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Preparation on the sdwebui side 30 | ### Installation 31 | 32 | This is a client application that uses the sdwebui API, so sdwebui must be launched beforehand. 33 | - [sdwebui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (tested with v1.10.1) 34 | By default, sdwebui does not enable the API! Please start it with the API enabled. 35 | https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API 36 | ```sh 37 | First, of course, is to run webui with --api commandline argument 38 | example in your "webui-user.bat": set COMMANDLINE_ARGS=--api 39 | ``` 40 | 41 | Install the following extensions. 42 | - Stable-Diffusion-Webui-Civitai-Helper(https://github.com/butaixianran/Stable-Diffusion-Webui-Civitai-Helper) 43 | 44 | - adetailer(https://github.com/Bing-su/adetailer) 45 | 46 | - sd-webui-controlnet(https://github.com/Mikubill/sd-webui-controlnet) 47 | You need to download the ControlNet models that match the model you want to use. 48 | 49 | - sd-webui-lora-block-weight(https://github.com/hako-mikan/sd-webui-lora-block-weight) 50 | You must set up the presets.
Below is an example of a preset when using pony/sdxl 51 | 52 | ```sh 53 | Face-1:1,0,0,0,0,0,1,1,1,0,0,0 54 | Face-2:1,0,1,0,1,0,0.8,1,1,0.6,0,1 55 | Face-3:1,0,0,0,1,0.3,1,1,1,0.8,0,0.4 56 | Noface:1,1,1,1,0,1,0,0,0,1,1,1 57 | Wear:1,1,1,0,0,0,1,1,1,0,0,0 58 | Pose:1,0,0,0,0,1,1,1,0,0,0,0 59 | ArtStyle:1,0,0,0,0,0,0,0,0,1,1,1 60 | Char:1,1,1,0,0,0,1,1,1,1,1,1 61 | BG:1,1,1,1,0,1,0,0,0,1,1,1 62 | Soft:1,0,0,0,0,0,1,1,1,1,1,1 63 | ``` 64 | 65 | 66 | The following extensions are low priority. 67 | 68 | 69 | - sd-webui-segment-anything(https://github.com/continue-revolution/sd-webui-segment-anything) 70 | You need to download one of the sam models and one of the DINO models. 71 | Please note that this extension has its own installation procedure. You need to set the option to use local GroundingDINO in the options settings. 72 | ```sh 73 | Due to the overwhelming complaints about GroundingDINO installation and the lack of substitution of similar high-performance text-to-bounding-box library, 74 | I decide to modify the source code of GroundingDINO and push to this repository. 75 | Starting from v1.5.0, you can choose to use local GroundingDINO by checking Use local groundingdino to bypass C++ problem on Settings/Segment Anything. 76 | This change should solve all problems about ninja, pycocotools, _C and any other problems related to C++/CUDA compilation. 77 | ``` 78 | 79 | - stable-diffusion-webui-promptgen(https://github.com/davidmartinrius/stable-diffusion-webui-promptgen/tree/api-implementation) 80 | Please note that you need to install the branch that contains the api support(api-implementation branch), not the original created by AUTOMATIC1111. 81 | https://github.com/mix1009/sdwebuiapi?tab=readme-ov-file#prompt-generator-api-by-david-martin-rius 82 | 83 | 84 | ### Organizing Lora Files 85 | This application assumes that lora files are organized in subfolders by type. 86 | The following is an example. 
Some files may be difficult to classify, but there is no need to be so strict. 87 | (If there are too many files to organize, you can put only the files you want to use in a subfolder.) 88 | 89 | ```sh 90 | YOUR_SD_PATH/models/Lora/Pony/character 91 | YOUR_SD_PATH/models/Lora/Pony/style 92 | YOUR_SD_PATH/models/Lora/Pony/pose 93 | YOUR_SD_PATH/models/Lora/Pony/item 94 | YOUR_SD_PATH/models/Lora/Pony/etc 95 | YOUR_SD_PATH/models/Lora/SDXL/character 96 | YOUR_SD_PATH/models/Lora/SDXL/style 97 | YOUR_SD_PATH/models/Lora/SDXL/pose 98 | YOUR_SD_PATH/models/Lora/SDXL/item 99 | YOUR_SD_PATH/models/Lora/SDXL/etc 100 | ``` 101 | Restart sdwebui when you have finished organizing the files. 102 | Then, use the Stable-Diffusion-Webui-Civitai-Helper extension you just installed to collect each LoRA's preview image, trigger words, and other information. 103 | (Civitai Helper Tab -> "Scan Models for Civitai" -> press "Scan" button) 104 | 105 | 106 | ## How To Use 107 | 108 | Run sdwebui. 109 | 110 | 111 | - GUI 112 | Run launch_GUI.bat. 113 | (Perform the [Initial Settings](#initial-settings) when starting up for the first time.) 114 | Go to [Generate -> Create New Sequence -> common -> prompt] 115 | Select some LoRA files. By default, one character LoRA is selected at random. 116 | Go to [Generate -> Create New Sequence -> common -> prompt -> preset_tags] 117 | Select a quality tag. 118 | Go to [Generate -> Create New Sequence -> common -> prompt -> header(or footer)] 119 | Enter a prompt. 120 | Go to [bottom left of app screen -> batch_count] 121 | For now, set it to about 4. 122 | Go to [bottom left of app screen -> Generate] 123 | 124 | 125 | 126 | 127 | - CLI 128 | Run launch_cmd.bat. 129 | 130 | ```sh 131 | python main.py 132 | ``` 133 | 134 | ### Works with [Speech Bubble Remove and Copy Tool](https://github.com/s9roll7/speech_bubble_remove_and_copy) 135 | Generate clean images with the tool. 136 | Create a directory containing only the images you want to use. 137 | Generate tags from images.
138 | Remove unwanted tags. 139 | Select txt2img or img2img. 140 | (Optional) Set controlnet parameters. 141 | Press Convert. 142 | Load the generated JSON file. 143 | Configure lora and prompt settings and run generation. 144 | Copy the generated image to "YOUR_PROJECT_DIR/base". 145 | 146 | 147 | 148 | 149 | 150 | ## Initial Settings 151 | There are a few items to set up on first startup. 152 | - Lora Directory Setting 153 | Set the path of the directory containing your lora files. Multiple environments can be set up, but only one is needed at first. 154 | - Config -> main 155 | Select a Stable Diffusion checkpoint file. 156 | Select the environment name set in [Lora Directory Setting]. 157 | - Generate -> Lora Update 158 | Run this every time you add or remove a lora file. 159 | 160 | 161 | ## Advanced Settings 162 | - Controlnet Alias Setting 163 | Several settings are included as samples. Set the model for any of them you want to use. 164 | - Config -> lora_generate_tag 165 | Automatically adds training tags extracted from each lora's safetensors metadata to the prompt. 166 | If you find the automatically added tags intrusive, adjust these settings. 167 | - Config -> lora_block_weight 168 | If you do not need lbw (lora block weight), disable it. 169 | It is safer to keep it enabled when combining multiple loras, as it makes the image less likely to break down. 170 | - Config -> adetailer 171 | If you want to generate even higher quality images, set [person_yolov8n] to 1 and [face_yolov8n] to 2. 172 | - Config -> segment_anything 173 | Set the SAM model. 174 | Set the DINO model. 175 | (Retrieving the DINO model name is not supported by the API, so you need to type the model name manually. 176 | See sdwebui txt2img -> Generation -> Segment Anything -> Enable GroundingDINO -> GroundingDINO Model.) 177 | 178 | 179 | ## Changelog 180 | ### 2024-12-22 181 | Added automatic generation of sequence files from image directories. 182 | Fixed a bug where dynamic prompts could not be used together. 183 | UI improvements.
184 | Some bug fix 185 | 186 | 187 | ### Related resources 188 | - [Stable Diffusion Web UI(A1111)](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 189 | - [sdwebuiapi](https://github.com/mix1009/sdwebuiapi) 190 | - [Stable-Diffusion-Webui-Civitai-Helper](https://github.com/butaixianran/Stable-Diffusion-Webui-Civitai-Helper) 191 | - [adetailer](https://github.com/Bing-su/adetailer) 192 | - [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet) 193 | - [sd-webui-lora-block-weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) 194 | - [sd-webui-segment-anything](https://github.com/continue-revolution/sd-webui-segment-anything) 195 | - [stable-diffusion-webui-promptgen](https://github.com/AUTOMATIC1111/stable-diffusion-webui-promptgen) 196 | - [stable-diffusion-webui-promptgen](https://github.com/davidmartinrius/stable-diffusion-webui-promptgen/tree/api-implementation) 197 | - [wd-eva02-large-tagger-v3](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3) 198 | 199 | -------------------------------------------------------------------------------- /sd_batch_runner/lora.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from pathlib import Path 5 | from datetime import datetime 6 | import math 7 | import random 8 | import re 9 | from enum import Enum 10 | 11 | 12 | from sd_batch_runner.util import * 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 18 | handler = logging.StreamHandler() 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | 22 | 23 | re_word = re.compile(r"[-_\w']+") 24 | 25 | 26 | 27 | class LoraType(str, Enum): 28 | Item = "item" 29 | Pose = "pose" 30 | Style = "style" 31 | Character = "char" 32 | All = "all" 33 | 34 | 35 | class Lora(): 36 | 37 | 
static_instance_map={} 38 | 39 | def __init__(self, user_data_path:Path, data_path:Path, tag_th, prohibited_tags, enable_gen_tag): 40 | 41 | r = config_get_lora_dir_env_root_path() 42 | user_data_path = r / user_data_path 43 | data_path = r / data_path 44 | 45 | self.data={} 46 | if data_path.is_file(): 47 | with open(data_path, "r", encoding="utf-8") as f: 48 | self.data = json.load(f) 49 | self.user_data={} 50 | if user_data_path.is_file(): 51 | with open(user_data_path, "r", encoding="utf-8") as f: 52 | self.user_data = json.load(f) 53 | 54 | self.default_weight = 1.0 55 | self.tag_th = tag_th 56 | self.prohibited_tags = prohibited_tags 57 | self.enable_gen_tag = enable_gen_tag 58 | 59 | self.random_order_cache = {} 60 | 61 | def _select(self, key): 62 | 63 | logger.info(f"lora select {key=}") 64 | 65 | cur_item = self.data.get(key, {}) 66 | 67 | if not cur_item: 68 | return None 69 | 70 | user_item = self.user_data.get(key, {}) 71 | for k in user_item: 72 | cur_item[k] = user_item[k] 73 | 74 | w = cur_item.get("weight", self.default_weight) 75 | 76 | a = Path(key).stem 77 | b = w 78 | c = "" 79 | 80 | if user_item: 81 | trigger = user_item.get("trigger", []) 82 | if trigger: 83 | index = random.randrange(0, len(trigger)) 84 | c = trigger[index] 85 | 86 | if not c: 87 | tags = cur_item.get("tags", []) 88 | if self.enable_gen_tag and tags: 89 | c = generate_prompt_from_tags(tags, self.tag_th, self.prohibited_tags) 90 | else: 91 | trigger = cur_item.get("trigger", []) 92 | if trigger: 93 | index = random.randrange(0, len(trigger)) 94 | c = trigger[index] 95 | 96 | return [a,b,c] 97 | 98 | def _select_one(self, filter): 99 | item_list = self.get_file_list() 100 | if not item_list: 101 | return None 102 | 103 | if (filter not in self.random_order_cache) or (not self.random_order_cache[filter]): 104 | if filter != "": 105 | filter_list = filter.split("|") 106 | filter_list = [f for f in filter_list if f] 107 | 108 | result = [] 109 | for f in filter_list: 110 | 
result += [name for name in item_list if name.lower().find(f) != -1] 111 | 112 | item_list = list(dict.fromkeys(result)) 113 | 114 | if len(item_list) == 0: 115 | return None 116 | else: 117 | random.shuffle(item_list) 118 | self.random_order_cache[filter] = item_list 119 | 120 | 121 | key = self.random_order_cache[filter].pop(0) 122 | 123 | return self._select(key) 124 | 125 | def _clear_filter(self): 126 | self.random_order_cache.clear() 127 | 128 | def get_file_dir(self): 129 | raise NotImplementedError() 130 | 131 | def get_file_list(self): 132 | tmp = list(self.data.keys()) 133 | tmp.sort() 134 | return tmp 135 | 136 | @classmethod 137 | def _create(cls, lora_type:LoraType): 138 | if lora_type not in Lora.static_instance_map: 139 | if lora_type == LoraType.Item: 140 | Lora.static_instance_map[lora_type] = ItemLora() 141 | elif lora_type == LoraType.Pose: 142 | Lora.static_instance_map[lora_type] = PoseLora() 143 | elif lora_type == LoraType.Style: 144 | Lora.static_instance_map[lora_type] = StyleLora() 145 | elif lora_type == LoraType.Character: 146 | Lora.static_instance_map[lora_type] = CharacterLora() 147 | 148 | @classmethod 149 | def select(cls, lora_type:LoraType, key): 150 | if lora_type == LoraType.All: 151 | types = [LoraType.Item,LoraType.Pose,LoraType.Style,LoraType.Character] 152 | else: 153 | types = [lora_type] 154 | 155 | for t in types: 156 | Lora._create(t) 157 | 158 | result = None 159 | for t in types: 160 | result = Lora.static_instance_map[t]._select(key) 161 | if result != None: 162 | break 163 | 164 | return result 165 | 166 | @classmethod 167 | def select_one(cls, lora_type:LoraType, filter=""): 168 | if lora_type == LoraType.All: 169 | types = [LoraType.Item,LoraType.Pose,LoraType.Style,LoraType.Character] 170 | else: 171 | types = [lora_type] 172 | 173 | for t in types: 174 | Lora._create(t) 175 | 176 | result = None 177 | for t in types: 178 | result = Lora.static_instance_map[t]._select_one(filter) 179 | if result != None: 180 | 
break 181 | 182 | return result 183 | 184 | @classmethod 185 | def create_instance(cls, lora_type:LoraType): 186 | if lora_type == LoraType.Item: 187 | return ItemLora() 188 | elif lora_type == LoraType.Pose: 189 | return PoseLora() 190 | elif lora_type == LoraType.Style: 191 | return StyleLora() 192 | elif lora_type == LoraType.Character: 193 | return CharacterLora() 194 | else: 195 | raise NotImplementedError() 196 | 197 | def lora_clear_cache(): 198 | Lora.static_instance_map.clear() 199 | 200 | 201 | class ItemLora(Lora): 202 | def __init__(self): 203 | tag_th = config_get_lora_generate_tag_th_item() 204 | prohibited_tags = config_get_lora_generate_tag_prohibited_tags_item() 205 | enable_gen_tag = config_get_lora_generate_tag_enable_item() 206 | super().__init__(Path("user_lora.json"), Path("item_lora.json"), tag_th, prohibited_tags, enable_gen_tag) 207 | def get_type(self): 208 | return LoraType.Item 209 | def get_file_dir(self): 210 | return config_get_item_lora_dir_path() 211 | 212 | class PoseLora(Lora): 213 | def __init__(self): 214 | tag_th = config_get_lora_generate_tag_th_pose() 215 | prohibited_tags = config_get_lora_generate_tag_prohibited_tags_pose() 216 | enable_gen_tag = config_get_lora_generate_tag_enable_pose() 217 | super().__init__(Path("user_lora.json"), Path("pose_lora.json"), tag_th, prohibited_tags, enable_gen_tag) 218 | def get_type(self): 219 | return LoraType.Pose 220 | def get_file_dir(self): 221 | return config_get_pose_lora_dir_path() 222 | 223 | class StyleLora(Lora): 224 | def __init__(self): 225 | tag_th = config_get_lora_generate_tag_th_style() 226 | prohibited_tags = config_get_lora_generate_tag_prohibited_tags_style() 227 | enable_gen_tag = config_get_lora_generate_tag_enable_style() 228 | super().__init__(Path("user_lora.json"), Path("style_lora.json"), tag_th, prohibited_tags, enable_gen_tag) 229 | def get_type(self): 230 | return LoraType.Style 231 | def get_file_dir(self): 232 | return config_get_style_lora_dir_path() 233 | 
class CharacterLora(Lora):
    """Character lora category (``character_lora.json`` / character lora directory)."""

    def __init__(self):
        # Config getters are evaluated in the order: threshold, prohibited tags, enable flag.
        super().__init__(
            Path("user_lora.json"),
            Path("character_lora.json"),
            config_get_lora_generate_tag_th_character(),
            config_get_lora_generate_tag_prohibited_tags_character(),
            config_get_lora_generate_tag_enable_character(),
        )

    def get_type(self):
        return LoraType.Character

    def get_file_dir(self):
        return config_get_character_lora_dir_path()


# Bracket characters that generate_prompt_from_tags backslash-escapes in each tag.
_PROMPT_ESCAPE = str.maketrans({ch: "\\" + ch for ch in "({[]})"})


def generate_prompt_from_tags(tags, th=-1, prohibited_tags=None):
    """Turn ``(tag, count)`` pairs into a sorted, comma-separated prompt.

    The first count seen is used as the reference maximum (tags are expected
    in descending count order, as produced by
    ``get_train_tags_from_safetensors``). A tag is kept when its count exceeds
    ``th * max_count``; with a negative *th* a fresh random fraction is drawn
    for every tag. Kept tags have ``( { [ ] } )`` backslash-escaped.

    Args:
        tags: iterable of ``(tag, count)`` pairs.
        th: fraction of the maximum count a tag must exceed, or negative
            for a random cut-off per tag.
        prohibited_tags: tags that are never emitted; ``None`` means none.
    """
    if prohibited_tags is None:
        prohibited_tags = []

    max_count = None
    kept = []
    for tag, count in tags:
        if not max_count:
            max_count = count

        if tag in prohibited_tags:
            logger.debug(f"ignore {tag=}")
            continue

        cutoff = (random.random() if th < 0 else th) * max_count
        if count > cutoff:
            kept.append(tag.translate(_PROMPT_ESCAPE))

    return ", ".join(sorted(kept))

# from https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/master/modules/sd_models.py
def get_train_tags_from_safetensors(file_path):
    """Return ``[(tag, count), ...]`` from a safetensors file's training metadata.

    Counts come from ``ss_tag_frequency`` in the ``__metadata__`` header and
    are summed across datasets; the list is ordered by count, highest first.
    Returns ``[]`` when the header carries no metadata.

    Raises:
        AssertionError: if the file does not start like a safetensors header.
    """

    def read_metadata_from_safetensors(filename):
        """Return the ``__metadata__`` dict, JSON-decoding string values that look like objects."""
        with open(filename, mode="rb") as file:
            # 8-byte little-endian header length, then the JSON header itself.
            metadata_len = int.from_bytes(file.read(8), "little")
            json_start = file.read(2)

            assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

            res = {}

            try:
                json_data = json_start + file.read(metadata_len - 2)
                json_obj = json.loads(json_data)
                for k, v in json_obj.get("__metadata__", {}).items():
                    res[k] = v
                    if isinstance(v, str) and v[0:1] == '{':
                        try:
                            res[k] = json.loads(v)
                        except Exception:
                            pass  # best effort: keep the raw string
            except Exception:
                logger.error(f"Error reading metadata from file: {filename}")

            return res

    def build_tags(metadata):
        """Aggregate ``ss_tag_frequency`` into ``(tag, count)`` pairs, highest count first."""

        def is_non_comma_tagset(tags):
            # Long average "tags" suggest sentence captions rather than tag lists.
            average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
            return average_tag_length >= 16

        tags = {}

        ss_tag_frequency = metadata.get("ss_tag_frequency", {})
        if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):
            for _, tags_dict in ss_tag_frequency.items():
                for tag, tag_count in tags_dict.items():
                    tag = tag.strip()
                    tags[tag] = tags.get(tag, 0) + int(tag_count)

        if tags and is_non_comma_tagset(tags):
            # Split captions into words (3+ chars) and count those instead.
            new_tags = {}
            for text, text_count in tags.items():
                for word in re.findall(re_word, text):
                    if len(word) < 3:
                        continue
                    new_tags[word] = new_tags.get(word, 0) + text_count
            tags = new_tags

        ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
        return [(tag, tags[tag]) for tag in ordered_tags]

    metadata = read_metadata_from_safetensors(file_path)
    if not metadata:
        return []
    # build_tags only reads "ss_tag_frequency", so the key order is irrelevant.
    return build_tags(metadata)


def update_lora(lora_dir_path: Path, json_path: Path, is_overwrite):
    """Synchronise *json_path* with the ``*.safetensors`` files below *lora_dir_path*.

    For every safetensors file found recursively, an entry keyed by file name
    is built from its ``.civitai.info`` sidecar (name, update date, trigger
    words, a weight used in Civitai sample images) plus its training tags.
    Existing entries are left untouched when ``is_overwrite == False``.
    Entries whose file no longer exists are removed. The result is written
    back to *json_path* as UTF-8 JSON.

    Returns:
        ``(actions, total)`` where *actions* is a list of ``(file_name, action)``
        with action in ``"skip"``, ``"add"``, ``"update"``, ``"delete"``, and
        *total* is the number of entries written.
    """
    data = {}
    if json_path.is_file():
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

    def search_weight(info, lora_name: str):
        """Return the weight Civitai sample images used for *lora_name*, or ``None``."""
        for img in info.get("images") or []:
            meta = img.get("meta") or {}
            for res in meta.get("resources") or []:
                name = res.get("name") or ""
                # An empty name would match every lora via startswith("").
                if name and lora_name.startswith(name):
                    weight = res.get("weight", None)
                    if weight is not None:
                        return weight
        return None

    result = []

    # Compare the suffix with ==: a substring test against the plain string
    # ".safetensors" also matched "" (directories, extension-less files),
    # which then crashed the metadata reader.
    safetensors_list = [
        p for p in lora_dir_path.glob("**/*")
        if p.suffix == ".safetensors" and p.is_file()
    ]

    for p in safetensors_list:
        data_exist = p.name in data

        if is_overwrite == False and data_exist:  # noqa: E712 - keep original semantics
            result.append((p.name, "skip"))
            continue

        info = {}
        info_path = p.with_suffix(".civitai.info")
        if info_path.is_file():
            # JSON is UTF-8; do not depend on the platform's locale encoding.
            with open(info_path, "r", encoding="utf-8") as f:
                info = json.load(f)

        data[p.name] = {
            "name": info.get("name", ""),
            "updated_date": info.get("updatedAt", ""),
            "trigger": info.get("trainedWords", []),
            "tags": get_train_tags_from_safetensors(p),
        }

        w = search_weight(info, p.stem)
        if w is not None:
            data[p.name]["weight"] = w

        result.append((p.name, "update" if data_exist else "add"))

    present = {p.name for p in safetensors_list}
    for k in sorted(set(data) - present):  # sorted: deterministic report order
        data.pop(k)
        result.append((k, "delete"))

    json_text = json.dumps(data, indent=4, ensure_ascii=False)
    json_path.write_text(json_text, encoding="utf-8")

    return result, len(data)

def print_update_result(text, result, total):
    """Log *text*, then every "add" entry, then every "delete" entry, then *total*."""
    logger.info(text)
    for wanted in ("add", "delete"):
        for name, action in result:
            if action == wanted:
                logger.info(f"{action} [{name}]")
    logger.info(f"Total : {total}")

def show_lora_key(json_path):
    """Log each key of the JSON object in *json_path* with its index (nothing if missing)."""
    data = {}
    if json_path.is_file():
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

    for i, key in enumerate(data):
        logger.info(f"[{i}] {key}")



###############################################################

# (label, lora directory getter, JSON file name) per category, in report order.
_LORA_CATEGORIES = (
    ("character", config_get_character_lora_dir_path, "character_lora.json"),
    ("style", config_get_style_lora_dir_path, "style_lora.json"),
    ("pose", config_get_pose_lora_dir_path, "pose_lora.json"),
    ("item", config_get_item_lora_dir_path, "item_lora.json"),
)

def update_lora_command(is_overwrite):
    """Run update_lora for every category of the current env, then log a summary."""
    logger.info(f"{is_overwrite=}")

    root = config_get_lora_dir_env_root_path()
    summaries = []
    for label, get_dir, json_name in _LORA_CATEGORIES:
        result, total = update_lora(get_dir(), root / Path(json_name), is_overwrite)
        summaries.append((f"== {label} lora ==", result, total))

    env = config_get_current_lora_dir_env()
    logger.info(f"[lora env = {env}]")
    for header, result, total in summaries:
        print_update_result(header, result, total)


def show_lora_command():
    """Log the keys of every category JSON of the current lora env."""
    root = config_get_lora_dir_env_root_path()
    env = config_get_current_lora_dir_env()
    logger.info(f"[lora env = {env}]")

    for label, _get_dir, json_name in _LORA_CATEGORIES:
        logger.info(f"== {label} lora ==")
        show_lora_key(root / Path(json_name))

def show_lora_env_command():
    """Log the configured lora envs with their index, then the current env."""
    for i, ev in enumerate(config_get_lora_dir_env_list()):
        logger.info(f"{i} : {ev}")

    env = config_get_current_lora_dir_env()
    logger.info(f"current [lora env = {env}]")

def set_lora_env_command(new_env):
    """Switch the current lora env to *new_env*.

    Raises:
        ValueError: if *new_env* is not a key of lora_dir.json.
    """
    env = config_get_current_lora_dir_env()
    logger.info(f"[lora env = {env}]")

    envs = config_get_lora_dir_env_list()
    if new_env not in envs:
        raise ValueError(f"{new_env=} is Not listed in lora_dir.json")

    config_set_current_lora_dir_env(new_env)

    env = config_get_current_lora_dir_env()
    logger.info(f"-> [lora env = {env}]")


def get_thumb_path(lora_type: LoraType, lora):
    """Return the ``.preview.png`` path that sits next to lora file *lora*.

    Raises:
        NotImplementedError: for ``LoraType.All`` or an unknown type (this
            previously surfaced as an ``UnboundLocalError``).
    """
    if lora_type == LoraType.Character:
        lora_dir = config_get_character_lora_dir_path()
    elif lora_type == LoraType.Style:
        lora_dir = config_get_style_lora_dir_path()
    elif lora_type == LoraType.Pose:
        lora_dir = config_get_pose_lora_dir_path()
    elif lora_type == LoraType.Item:
        lora_dir = config_get_item_lora_dir_path()
    else:
        raise NotImplementedError(f"unsupported lora type: {lora_type!r}")

    return (lora_dir / Path(lora)).with_suffix(".preview.png")

def _iter_lora_files(lora_type):
    """Yield the paths of this category's known lora files that exist on disk, by name."""
    lora_obj = Lora.create_instance(lora_type)
    file_list = lora_obj.get_file_list()
    lora_dir = Path(lora_obj.get_file_dir())
    for name in file_list:
        p = lora_dir / Path(name)
        if p.is_file():
            yield p

def get_lora_files_and_thumbs(lora_type: LoraType, thumb_size):
    """Return ``[(stem, ctime, thumb_or_None), ...]`` for the existing lora files."""
    result = []
    for p in _iter_lora_files(lora_type):
        thumb_path = p.with_suffix(".preview.png")
        thumb = get_thumb(thumb_path, thumb_size) if thumb_path.is_file() else None
        result.append((p.stem, p.stat().st_ctime, thumb))
    return result

def get_lora_files_and_preview_paths(lora_type: LoraType):
    """Return ``[(stem, preview_path_or_None), ...]`` for the existing lora files."""
    result = []
    for p in _iter_lora_files(lora_type):
        thumb_path = p.with_suffix(".preview.png")
        result.append((p.stem, thumb_path if thumb_path.is_file() else None))
    return result


# -------------------------------------------------------------------------------- /sd_batch_runner/util.py: --------------------------------------------------------------------------------
import json
import logging
import time
from pathlib import Path
from datetime import datetime
import shutil
import random
import re

import cv2
import numpy as np
from PIL import Image


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)


#######################################################################
def config_clear_cache():
    """Reset every module-level config cache so the next access re-reads the files."""
    global _conf, _cn_conf, _lora_conf, _file_list_cache, _preset_tags_conf
    _conf = {}
    _cn_conf = {}
    _lora_conf = {}
    _file_list_cache = {}
    _preset_tags_conf = {}



#######################################################################
## config

CONFIG_FILE_PATH = "config.json"
_conf = {}

def get_config_dict():
    """Return config.json as a dict, loading it on first use (cached in ``_conf``)."""
    global _conf
    if not _conf:
        with open(CONFIG_FILE_PATH, "r", encoding="utf-8") as f:
            _conf = json.load(f)

    return _conf

def update_config():
    """Write the cached config back to config.json (no-op when nothing is loaded)."""
    if _conf:
        json_text = json.dumps(_conf, indent=4, ensure_ascii=False)
        Path(CONFIG_FILE_PATH).write_text(json_text, encoding="utf-8")

def config_get_default_checkpoint():
    c = get_config_dict()
    return c["default_checkpoint"]

def config_set_default_checkpoint(new_val):
    c = get_config_dict()
    c["default_checkpoint"] = new_val
    update_config()


def config_set_current_lora_dir_env(new_val):
    c = get_config_dict()
    c["lora_dir_env"] = new_val
    update_config()

def config_get_current_lora_dir_env():
    c = get_config_dict()
    return c["lora_dir_env"]

def config_get_lora_dir_env_root_path():
    """Return ``lora_dir_env/<current env>``, creating the directory if needed."""
    c = get_config_dict()
    p = Path("lora_dir_env") / Path(c["lora_dir_env"])
    p.mkdir(parents=True, exist_ok=True)
    return p



def config_get_default_generation_setting(is_txt2img):
    """Merge the common generation settings with the txt2img or img2img ones.

    ``dict(**a, **b)`` raises TypeError if both sections define the same key.
    """
    c = get_config_dict()
    if is_txt2img:
        return dict(**c["generation_setting_common"], **c["generation_setting_txt2img"])
    else:
        return dict(**c["generation_setting_common"], **c["generation_setting_img2img"])



def config_get_lora_generate_tag_enable_character():
    c = get_config_dict()
    return c["lora_generate_tag"]["enable_character"]

def config_get_lora_generate_tag_enable_style():
    c = get_config_dict()
    return c["lora_generate_tag"]["enable_style"]

def config_get_lora_generate_tag_enable_pose():
    c = get_config_dict()
    return c["lora_generate_tag"]["enable_pose"]

def config_get_lora_generate_tag_enable_item():
    c = get_config_dict()
    return c["lora_generate_tag"]["enable_item"]


def config_get_lora_generate_tag_th_character():
    c = get_config_dict()
    return c["lora_generate_tag"]["tag_th_character"]

def config_get_lora_generate_tag_th_style():
    c = get_config_dict()
    return c["lora_generate_tag"]["tag_th_style"]

def config_get_lora_generate_tag_th_pose():
    c = get_config_dict()
    return c["lora_generate_tag"]["tag_th_pose"]

def config_get_lora_generate_tag_th_item():
    c = get_config_dict()
    return c["lora_generate_tag"]["tag_th_item"]

def config_get_lora_generate_tag_prohibited_tags_character():
    """Return ``lora_generate_tag.prohibited_tags_character`` from config.json."""
    c = get_config_dict()
    return c["lora_generate_tag"]["prohibited_tags_character"]

def config_get_lora_generate_tag_prohibited_tags_style():
    """Return ``lora_generate_tag.prohibited_tags_style`` from config.json."""
    c = get_config_dict()
    return c["lora_generate_tag"]["prohibited_tags_style"]

def config_get_lora_generate_tag_prohibited_tags_pose():
    """Return ``lora_generate_tag.prohibited_tags_pose`` from config.json."""
    c = get_config_dict()
    return c["lora_generate_tag"]["prohibited_tags_pose"]

def config_get_lora_generate_tag_prohibited_tags_item():
    """Return ``lora_generate_tag.prohibited_tags_item`` from config.json."""
    c = get_config_dict()
    return c["lora_generate_tag"]["prohibited_tags_item"]


def _lbw_section(base_type, lora_index):
    """Return the ``lora_block_weight`` config section for one lora slot.

    Slot ``0`` reads the section named *base_type*; any other index reads
    ``"<base_type>2"`` (only 0 vs. non-0 is distinguished).
    """
    section = base_type if lora_index == 0 else base_type + "2"
    return get_config_dict()["lora_block_weight"][section]


# Each getter below returns one field of the character/style/pose/item
# lora_block_weight section for the given lora slot.

def config_get_lbw_enable_character(lora_index):
    """Return ``enable_lbw`` of the character section for *lora_index*."""
    return _lbw_section("character", lora_index)["enable_lbw"]
def config_get_lbw_preset_character(lora_index):
    """Return ``preset`` of the character section for *lora_index*."""
    return _lbw_section("character", lora_index)["preset"]
def config_get_lbw_start_stop_step_character(lora_index):
    """Return ``start_stop_step`` of the character section for *lora_index*."""
    return _lbw_section("character", lora_index)["start_stop_step"]
def config_get_lbw_start_stop_step_value_character(lora_index):
    """Return ``start_stop_step_value`` of the character section for *lora_index*."""
    return _lbw_section("character", lora_index)["start_stop_step_value"]

def config_get_lbw_enable_style(lora_index):
    """Return ``enable_lbw`` of the style section for *lora_index*."""
    return _lbw_section("style", lora_index)["enable_lbw"]
def config_get_lbw_preset_style(lora_index):
    """Return ``preset`` of the style section for *lora_index*."""
    return _lbw_section("style", lora_index)["preset"]
def config_get_lbw_start_stop_step_style(lora_index):
    """Return ``start_stop_step`` of the style section for *lora_index*."""
    return _lbw_section("style", lora_index)["start_stop_step"]
def config_get_lbw_start_stop_step_value_style(lora_index):
    """Return ``start_stop_step_value`` of the style section for *lora_index*."""
    return _lbw_section("style", lora_index)["start_stop_step_value"]

def config_get_lbw_enable_pose(lora_index):
    """Return ``enable_lbw`` of the pose section for *lora_index*."""
    return _lbw_section("pose", lora_index)["enable_lbw"]
def config_get_lbw_preset_pose(lora_index):
    """Return ``preset`` of the pose section for *lora_index*."""
    return _lbw_section("pose", lora_index)["preset"]
def config_get_lbw_start_stop_step_pose(lora_index):
    """Return ``start_stop_step`` of the pose section for *lora_index*."""
    return _lbw_section("pose", lora_index)["start_stop_step"]
def config_get_lbw_start_stop_step_value_pose(lora_index):
    """Return ``start_stop_step_value`` of the pose section for *lora_index*."""
    return _lbw_section("pose", lora_index)["start_stop_step_value"]

def config_get_lbw_enable_item(lora_index):
    """Return ``enable_lbw`` of the item section for *lora_index*."""
    return _lbw_section("item", lora_index)["enable_lbw"]
def config_get_lbw_preset_item(lora_index):
    """Return ``preset`` of the item section for *lora_index*."""
    return _lbw_section("item", lora_index)["preset"]
def config_get_lbw_start_stop_step_item(lora_index):
    """Return ``start_stop_step`` of the item section for *lora_index*."""
    return _lbw_section("item", lora_index)["start_stop_step"]
def config_get_lbw_start_stop_step_value_item(lora_index):
    """Return ``start_stop_step_value`` of the item section for *lora_index*."""
    return _lbw_section("item", lora_index)["start_stop_step_value"]
205 | 206 | 207 | 208 | def config_get_adetailer_setting(): 209 | c = get_config_dict() 210 | return c["adetailer"] 211 | 212 | def config_get_default_prompt_gen_setting(): 213 | c = get_config_dict() 214 | return c["prompt_gen_setting"] 215 | 216 | def config_get_default_overwrite_generation_setting(): 217 | c = get_config_dict() 218 | return c["overwrite_generation_setting"] 219 | 220 | def config_get_segment_anything_sam_model_name(): 221 | c = get_config_dict() 222 | return c["segment_anything"]["sam_model_name"] 223 | 224 | def config_get_segment_anything_dino_model_name(): 225 | c = get_config_dict() 226 | return c["segment_anything"]["dino_model_name"] 227 | 228 | 229 | 230 | 231 | ####################################################################### 232 | ## controlnet 233 | 234 | CONTROLNET_FILE_PATH = "controlnet.json" 235 | _cn_conf = {} 236 | 237 | def get_cn_config_dict(): 238 | global _cn_conf 239 | if not _cn_conf: 240 | with open(CONTROLNET_FILE_PATH, "r", encoding="utf-8") as f: 241 | _cn_conf = json.load(f) 242 | 243 | return _cn_conf 244 | 245 | def get_controlnet_setting(name): 246 | c = get_cn_config_dict() 247 | return c[name] 248 | 249 | ####################################################################### 250 | ## lora dir 251 | 252 | LORA_DIR_FILE_PATH = "lora_dir.json" 253 | _lora_conf = {} 254 | 255 | def get_lora_config_dict(): 256 | global _lora_conf 257 | if not _lora_conf: 258 | with open(LORA_DIR_FILE_PATH, "r", encoding="utf-8") as f: 259 | _lora_conf = json.load(f) 260 | 261 | return _lora_conf 262 | 263 | def config_get_item_lora_dir_path(): 264 | c = get_lora_config_dict() 265 | env = config_get_current_lora_dir_env() 266 | return Path(c[env]["item_dir_path"]) 267 | 268 | def config_get_pose_lora_dir_path(): 269 | c = get_lora_config_dict() 270 | env = config_get_current_lora_dir_env() 271 | return Path(c[env]["pose_dir_path"]) 272 | 273 | def config_get_style_lora_dir_path(): 274 | c = get_lora_config_dict() 275 | env = 
config_get_current_lora_dir_env() 276 | return Path(c[env]["style_dir_path"]) 277 | 278 | def config_get_character_lora_dir_path(): 279 | c = get_lora_config_dict() 280 | env = config_get_current_lora_dir_env() 281 | return Path(c[env]["character_dir_path"]) 282 | 283 | def config_get_lora_dir_env_list(): 284 | c = get_lora_config_dict() 285 | logger.info(f"{c=}") 286 | return list(c.keys()) 287 | 288 | 289 | 290 | ####################################################################### 291 | ## preset tags 292 | 293 | PRESET_TAGS_FILE_PATH = "preset_tags.json" 294 | _preset_tags_conf = {} 295 | 296 | def get_preset_tags_config_dict(): 297 | global _preset_tags_conf 298 | if not _preset_tags_conf: 299 | with open(PRESET_TAGS_FILE_PATH, "r", encoding="utf-8") as f: 300 | _preset_tags_conf = json.load(f) 301 | 302 | return _preset_tags_conf 303 | 304 | def config_get_preset_tags_info(preset_name): 305 | c = get_preset_tags_config_dict() 306 | return c[preset_name] 307 | 308 | 309 | 310 | ####################################################################### 311 | 312 | 313 | def get_time_str(): 314 | return datetime.now().strftime("%Y%m%d_%H%M%S") 315 | 316 | 317 | _file_list_cache={} 318 | 319 | def select_one_file(dir_path:Path, suffixes, is_random=True): 320 | 321 | key = dir_path 322 | if key not in _file_list_cache: 323 | file_list = [p for p in dir_path.glob("**/*")] 324 | file_list = [p for p in file_list if p.suffix in suffixes] 325 | if len(file_list) == 0: 326 | raise ValueError(f"file not found in {dir_path=}") 327 | if is_random: 328 | random.shuffle(file_list) 329 | _file_list_cache[key] = file_list 330 | 331 | item = _file_list_cache[key].pop(0) 332 | if len(_file_list_cache[key]) == 0: 333 | _file_list_cache.pop(key) 334 | return item 335 | 336 | 337 | _video_cache={} 338 | 339 | def select_frame(movie_path:Path, interval_sec:float): 340 | import av 341 | key = movie_path 342 | if key not in _video_cache: 343 | video = av.open( movie_path ) 344 | 
_video_cache[key] = [video, 0] 345 | 346 | video = _video_cache[key][0] 347 | cur_sec = _video_cache[key][1] 348 | 349 | stream = video.streams.video[0] 350 | offset = int(cur_sec // stream.time_base) 351 | logger.info(f"{offset=}") 352 | 353 | video.seek( 354 | offset = offset, 355 | any_frame=False, 356 | backward=True, 357 | stream=stream 358 | ) 359 | 360 | while True: 361 | frame = next(video.decode(video=0), None) 362 | if frame is None: 363 | raise ValueError(f"seek failed. {movie_path=} {interval_sec=} {cur_sec=}") 364 | 365 | logger.debug(f"seeking :{frame.time}") 366 | if frame.time >= cur_sec: 367 | break 368 | 369 | image = Image.fromarray( frame.to_ndarray(format="rgb24") ) 370 | 371 | cur_sec += interval_sec 372 | if float(stream.duration * stream.time_base) <= cur_sec: 373 | cur_sec = 0 374 | _video_cache[key][1] = cur_sec 375 | 376 | return image 377 | 378 | 379 | def clear_video_cache(): 380 | for key in list(_video_cache.keys()): 381 | item = _video_cache.pop(key) 382 | item[0].close() 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | ######################################################### 392 | # from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/infotext_utils.py 393 | 394 | re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' 395 | re_param = re.compile(re_param_code) 396 | re_imagesize = re.compile(r"^(\d+)x(\d+)$") 397 | 398 | def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): 399 | """parses generation parameters string, the one you see in text field under the picture in UI: 400 | ``` 401 | girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate 402 | Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy 
drawing 403 | Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b 404 | ``` 405 | 406 | returns a dict with field values 407 | """ 408 | 409 | def unquote(text): 410 | if len(text) == 0 or text[0] != '"' or text[-1] != '"': 411 | return text 412 | 413 | try: 414 | return json.loads(text) 415 | except Exception: 416 | return text 417 | 418 | 419 | if skip_fields is None: 420 | skip_fields = [] 421 | 422 | res = {} 423 | 424 | prompt = "" 425 | negative_prompt = "" 426 | 427 | done_with_prompt = False 428 | 429 | *lines, lastline = x.strip().split("\n") 430 | if len(re_param.findall(lastline)) < 3: 431 | lines.append(lastline) 432 | lastline = '' 433 | 434 | for line in lines: 435 | line = line.strip() 436 | if line.startswith("Negative prompt:"): 437 | done_with_prompt = True 438 | line = line[16:].strip() 439 | if done_with_prompt: 440 | negative_prompt += ("" if negative_prompt == "" else "\n") + line 441 | else: 442 | prompt += ("" if prompt == "" else "\n") + line 443 | 444 | for k, v in re_param.findall(lastline): 445 | try: 446 | if v[0] == '"' and v[-1] == '"': 447 | v = unquote(v) 448 | 449 | m = re_imagesize.match(v) 450 | if m is not None: 451 | res[f"{k}-1"] = m.group(1) 452 | res[f"{k}-2"] = m.group(2) 453 | else: 454 | res[k] = v 455 | except Exception: 456 | print(f"Error parsing \"{k}: {v}\"") 457 | 458 | res["Prompt"] = prompt 459 | res["Negative prompt"] = negative_prompt 460 | 461 | # Missing CLIP skip means it was set to 1 (the default) 462 | if "Clip skip" not in res: 463 | res["Clip skip"] = "1" 464 | 465 | hypernet = res.get("Hypernet", None) 466 | if hypernet is not None: 467 | res["Prompt"] += f"""""" 468 | 469 | if "Hires resize-1" not in res: 470 | res["Hires resize-1"] = 0 471 | res["Hires resize-2"] = 0 472 | 473 | if "Hires sampler" not in res: 474 | res["Hires sampler"] = "Use same sampler" 475 | 476 | if "Hires schedule type" not in res: 477 | res["Hires schedule type"] = "Use same scheduler" 
478 | 479 | if "Hires checkpoint" not in res: 480 | res["Hires checkpoint"] = "Use same checkpoint" 481 | 482 | if "Hires prompt" not in res: 483 | res["Hires prompt"] = "" 484 | 485 | if "Hires negative prompt" not in res: 486 | res["Hires negative prompt"] = "" 487 | 488 | if "Mask mode" not in res: 489 | res["Mask mode"] = "Inpaint masked" 490 | 491 | if "Masked content" not in res: 492 | res["Masked content"] = 'original' 493 | 494 | if "Inpaint area" not in res: 495 | res["Inpaint area"] = "Whole picture" 496 | 497 | if "Masked area padding" not in res: 498 | res["Masked area padding"] = 32 499 | 500 | 501 | # Missing RNG means the default was set, which is GPU RNG 502 | if "RNG" not in res: 503 | res["RNG"] = "GPU" 504 | 505 | if "Schedule type" not in res: 506 | res["Schedule type"] = "Automatic" 507 | 508 | if "Schedule max sigma" not in res: 509 | res["Schedule max sigma"] = 0 510 | 511 | if "Schedule min sigma" not in res: 512 | res["Schedule min sigma"] = 0 513 | 514 | if "Schedule rho" not in res: 515 | res["Schedule rho"] = 0 516 | 517 | if "VAE Encoder" not in res: 518 | res["VAE Encoder"] = "Full" 519 | 520 | if "VAE Decoder" not in res: 521 | res["VAE Decoder"] = "Full" 522 | 523 | if "FP8 weight" not in res: 524 | res["FP8 weight"] = "Disable" 525 | 526 | if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": 527 | res["Cache FP16 weight for LoRA"] = False 528 | 529 | if "Refiner switch by sampling steps" not in res: 530 | res["Refiner switch by sampling steps"] = False 531 | 532 | for key in skip_fields: 533 | res.pop(key, None) 534 | 535 | return res 536 | 537 | 538 | 539 | #################################################################### 540 | def crop_mask(im:Image.Image, th): 541 | im_array = np.array(im) 542 | coords = np.argwhere(im_array > th) 543 | x_min, y_min = coords.min(axis=0) 544 | x_max, y_max = coords.max(axis=0) 545 | cropped = im_array[x_min:x_max+1, y_min:y_max+1] 546 | return Image.fromarray(cropped) 
547 | 548 | def get_box_of_mask(im:Image.Image, th): 549 | im_array = np.array(im) 550 | coords = np.argwhere(im_array > th) 551 | x_min, y_min = coords.min(axis=0) 552 | x_max, y_max = coords.max(axis=0) 553 | return (x_min, x_max+1), (y_min, y_max+1) 554 | 555 | def get_center_of_mask(im:Image.Image, th): 556 | (x0,x1),(y0,y1) = get_box_of_mask(im,th) 557 | return ((x0+x1)//2 , (y0+y1)//2) 558 | 559 | def create_focus_image(im:Image.Image, pos, scale): 560 | w, h = im.size 561 | logger.info(f"{scale=}") 562 | 563 | scale = float(scale) 564 | 565 | im_array = np.array(im) 566 | 567 | if pos: 568 | cx = pos[1] 569 | cy = pos[0] 570 | else: 571 | cx = w/2 572 | cy = h/2 573 | 574 | logger.info(f"{scale=}") 575 | if scale > 1.0: 576 | logger.info(f"scale > 1.0") 577 | cxmin = (w/scale)/2 578 | cxmax = w- (w/scale)/2 579 | 580 | logger.info(f"cxmin:{cxmin} cxmax:{cxmax}") 581 | 582 | cx = min(max(cx, cxmin), cxmax) 583 | 584 | cymin = (h/scale)/2 585 | cymax = h- (h/scale)/2 586 | 587 | logger.info(f"cymin:{cymin} cymax:{cymax}") 588 | 589 | cy = min(max(cy, cymin), cymax) 590 | 591 | 592 | logger.info(f"x:{pos[1]} y:{pos[0]}") 593 | logger.info(f"focus to {cx=} {cy=}") 594 | 595 | M = cv2.getRotationMatrix2D((float(cx), float(cy)), 0, float(scale)) 596 | expand_im = cv2.warpAffine(im_array, M, (w, h)) 597 | 598 | return Image.fromarray(expand_im) 599 | 600 | 601 | 602 | 603 | #################################################################### 604 | 605 | _thumb_cache = {} 606 | 607 | def get_thumb(path:Path, size): 608 | global _thumb_cache 609 | 610 | def im_2_b64(image): 611 | from io import BytesIO 612 | import base64 613 | buff = BytesIO() 614 | image.convert('RGB').save(buff, format="JPEG") 615 | img_str = base64.b64encode(buff.getvalue()) 616 | return img_str 617 | 618 | key = (path,size) 619 | if key not in _thumb_cache: 620 | thumb = Image.open(path) 621 | thumb.thumbnail(size=size) 622 | _thumb_cache[key] = im_2_b64(thumb) 623 | 624 | return 
_thumb_cache[key].decode('ascii') 625 | 626 | 627 | #################################################################### 628 | 629 | def config_restore_files_if_needed(): 630 | config_files = [ 631 | "config.json", 632 | "preset_tags.json", 633 | "lora_dir.json", 634 | "controlnet.json", 635 | ] 636 | 637 | for c in config_files: 638 | if Path(c).is_file() == False: 639 | shutil.copy( Path(f"default_config/default_{c}"), c) 640 | 641 | 642 | #################################################################### 643 | def get_image_file_list(img_dir): 644 | img_list = [p for p in Path(img_dir).glob("*") if re.search(r'.*\.(jpg|png|webp)', str(p))] 645 | return sorted(img_list) 646 | 647 | 648 | -------------------------------------------------------------------------------- /sd_batch_runner/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from pathlib import Path 5 | from datetime import datetime 6 | import math 7 | import random 8 | from copy import deepcopy 9 | from PIL import Image, PngImagePlugin 10 | from enum import Enum 11 | import shutil 12 | import traceback 13 | 14 | import webuiapi 15 | 16 | from sd_batch_runner.util import * 17 | from sd_batch_runner.lora import Lora,LoraType 18 | from sd_batch_runner.sdwebui_temp_fix import ControlNetUnit2 19 | 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.INFO) 22 | formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)s %(funcName)s [%(levelname)s]: %(message)s") 23 | handler = logging.StreamHandler() 24 | handler.setFormatter(formatter) 25 | logger.addHandler(handler) 26 | 27 | 28 | 29 | SD_HOST = "127.0.0.1" 30 | SD_PORT = 7860 31 | 32 | 33 | 34 | 35 | DEFAULT_GENERATION_COMMON_SETTING={ 36 | "prompt":{ 37 | "character_lora" : "@random" 38 | } 39 | } 40 | 41 | DEFAULT_GENERATION_SEQ_SETTING = [{ 42 | "type":"txt2img", 43 | 44 | }] 45 | 46 | 47 | class RandomPicker(): 48 | def 
__init__(self, common_rule, seq_rule): 49 | def to_list(a): 50 | return a if type(a) == list else [a,None] 51 | 52 | self.common_rule = to_list(common_rule) 53 | self.seq_rule = [to_list(a) for a in seq_rule] 54 | self._validate() 55 | 56 | if self.common_rule[0] in ("@random_once", "@random_per_seq", "@random"): 57 | #self.common_value = self._select_one() 58 | self.common_value = None 59 | else: 60 | self.common_value = self._select( self.common_rule[0], self.common_rule[1] ) 61 | 62 | 63 | self.seq_value = [] 64 | for sr in self.seq_rule: 65 | if sr[0] in ("@random_once", "@random_per_seq", "@random"): 66 | #self.seq_value.append(self._select_one()) 67 | self.seq_value.append( None ) 68 | else: 69 | self.seq_value.append(self._select( sr[0], sr[1] )) 70 | 71 | def _validate(self): 72 | for i in range(len(self.seq_rule)): 73 | if self.seq_rule[i][0] == "@random_per_seq": 74 | self.seq_rule[i][0] = "@random" 75 | 76 | def pick(self, index): 77 | 78 | def update(rule, index, v): 79 | if rule[0] == "@random_once": 80 | if v is not None: 81 | return v 82 | else: 83 | return self._select_one( rule[1] ) 84 | elif rule[0] == "@random_per_seq": 85 | if index == 0: 86 | return self._select_one( rule[1] ) 87 | else: 88 | return v 89 | elif rule[0] == "@random": 90 | return self._select_one( rule[1] ) 91 | 92 | return self._select( rule[0], rule[1] ) 93 | 94 | self.common_value = update( self.common_rule, index, self.common_value) 95 | 96 | self.seq_value[index] = update( self.seq_rule[index], index, self.seq_value[index]) 97 | 98 | seq_v = self.seq_value[index] 99 | 100 | if seq_v is None: 101 | return self.common_value 102 | 103 | return seq_v 104 | 105 | class SeedPicker(RandomPicker): 106 | def __init__(self, common_rule, seq_rule): 107 | super().__init__(common_rule, seq_rule) 108 | 109 | def _select_one(self, opt): 110 | return int(random.randrange(4294967294)) 111 | 112 | def _select(self, v, opt): 113 | if v == None: 114 | return None 115 | return int(v) 116 | 117 | 
class LoraPicker(RandomPicker): 118 | def __init__(self, lora_type:LoraType, common_rule, seq_rule): 119 | self.lora_type = lora_type 120 | super().__init__(common_rule, seq_rule) 121 | 122 | def _adjust_weight(self, picked, opt): 123 | if picked is not None: 124 | if opt is not None: 125 | stem, weight, trigger = picked 126 | picked = [stem, weight * opt, trigger] 127 | return picked 128 | 129 | def _select_one(self, opt): 130 | filter = "" 131 | lora_str = opt 132 | if type(opt) == list: 133 | if len(opt) > 1: 134 | filter = opt[1] 135 | lora_str = opt[0] 136 | 137 | picked = Lora.select_one(self.lora_type, filter) 138 | return self._adjust_weight(picked, lora_str) 139 | 140 | def _select(self, v, opt): 141 | picked = Lora.select(self.lora_type, v) 142 | return self._adjust_weight(picked, opt) 143 | 144 | def get_type(self): 145 | return self.lora_type 146 | 147 | class LoraAllPicker(): 148 | def __init__(self, common_rule, seq_rule): 149 | def to_list(a): 150 | return a if type(a) == list else [a,None] 151 | def validate(a): 152 | return [] if a == None else a 153 | 154 | common_rule = validate(common_rule) 155 | seq_rule = [validate(s) for s in seq_rule] 156 | 157 | self.common_rule = [to_list(c) for c in common_rule] 158 | self.seq_rule = [[to_list(a) for a in seq] for seq in seq_rule] 159 | 160 | self.common_value = [self._select( c[0], c[1] ) for c in self.common_rule] 161 | self.seq_value = [[self._select( s[0], s[1] ) for s in seq] for seq in self.seq_rule] 162 | 163 | logger.info(f"{self.common_value=}") 164 | logger.info(f"{self.seq_value=}") 165 | 166 | def _adjust_weight(self, picked, opt): 167 | if picked is not None: 168 | if opt is not None: 169 | stem, weight, trigger = picked 170 | picked = [stem, weight * opt, trigger] 171 | return picked 172 | 173 | def _select(self, v, opt): 174 | picked = Lora.select(LoraType.All, v) 175 | return self._adjust_weight(picked, opt) 176 | 177 | def pick(self, index): 178 | seq_v = self.seq_value[index] 179 | if 
seq_v in (None, []): 180 | return self.common_value 181 | 182 | return seq_v 183 | 184 | 185 | class RandomPromptPicker(RandomPicker): 186 | def __init__(self, sd, common_rule, seq_rule): 187 | def convert(a): 188 | table={ 189 | "once" : "@random_once", 190 | "per_seq" : "@random_per_seq", 191 | "any_time" : "@random" 192 | } 193 | return table.get(a, None) 194 | 195 | common_rule = convert(common_rule) 196 | seq_rule = [convert(s) for s in seq_rule] 197 | 198 | self.sd = sd 199 | super().__init__(common_rule, seq_rule) 200 | 201 | def pick(self, index, gen_setting): 202 | self.gen_setting = gen_setting 203 | result = super().pick(index) 204 | 205 | if result is None: 206 | return [None,None] 207 | 208 | lines = result.strip().split("Negative prompt") 209 | prompt = lines[0] 210 | if len(lines) > 1: 211 | negative_prompt = lines[1] 212 | negative_prompt = negative_prompt[16:].strip() 213 | else: 214 | negative_prompt = "" 215 | 216 | return [ prompt , negative_prompt] 217 | 218 | def _select_one(self, opt): 219 | return self.sd.prompt_gen(self.gen_setting) 220 | 221 | def _select(self, v, opt): 222 | return v 223 | 224 | 225 | 226 | class PresetTags(): 227 | def __init__(self, common_rule, seq_rule): 228 | self.common_rule = common_rule 229 | self.seq_rule = seq_rule 230 | def _apply(self, preset_name, prompt, neg_prompt): 231 | preset_info = config_get_preset_tags_info(preset_name) 232 | 233 | if preset_info["is_footer"]: 234 | prompt_list = [ prompt, preset_info["prompt"] ] 235 | neg_prompt_list = [ neg_prompt, preset_info["negative_prompt"] ] 236 | else: 237 | prompt_list = [ preset_info["prompt"], prompt ] 238 | neg_prompt_list = [ preset_info["negative_prompt"], neg_prompt ] 239 | 240 | if preset_info["prompt"]: 241 | prompt = ",".join(prompt_list) 242 | if preset_info["negative_prompt"]: 243 | neg_prompt = ",".join(neg_prompt_list) 244 | 245 | return prompt, neg_prompt 246 | 247 | def apply(self,index, prompt, neg_prompt): 248 | seq = self.seq_rule[index] 
249 | cur = seq if seq else self.common_rule 250 | 251 | if cur: 252 | for tag in cur: 253 | prompt, neg_prompt = self._apply(tag, prompt, neg_prompt) 254 | 255 | return prompt, neg_prompt 256 | 257 | 258 | class InputSource(): 259 | 260 | static_seq = 0 261 | static_instance_map={} 262 | 263 | class InnerInput(): 264 | def __init__(self, dir_path:Path, rule): 265 | self.seq = -1 266 | self.dir_path = dir_path 267 | self.rule = rule 268 | self.cache = None #[seq,index,img] 269 | 270 | def pick(self,index): 271 | if self.cache: 272 | if self.cache[0] == -1: 273 | pass 274 | elif self.cache[0] != InputSource.static_seq: 275 | self.cache = None 276 | elif self.rule == "@random": 277 | if self.cache[1] != index: 278 | self.cache = None 279 | 280 | if self.cache is None: 281 | img = self.get_image() 282 | if self.rule == "@random_once": 283 | self.cache = [-1,-1,img] 284 | else: 285 | self.cache = [InputSource.static_seq,index,img] 286 | 287 | return self.cache[2] 288 | 289 | class InnerInputDir(InnerInput): 290 | def __init__(self, dir_path:Path, rule, is_random): 291 | super().__init__(dir_path, rule) 292 | self.is_random = is_random 293 | def get_image(self): 294 | #return Image.open(select_one_file(self.dir_path, [".jpg",".png",".JPG",".PNG"], is_random=self.is_random)) 295 | return select_one_file(self.dir_path, [".jpg",".png",".JPG",".PNG"], is_random=self.is_random) 296 | 297 | class InnerInputMov(InnerInput): 298 | def __init__(self, dir_path:Path, interval): 299 | super().__init__(dir_path, "@random") 300 | self.interval = interval 301 | def get_image(self): 302 | return select_frame(self.dir_path, self.interval) 303 | 304 | 305 | @classmethod 306 | def update_seq(cls, seq): 307 | InputSource.static_seq = seq 308 | 309 | @classmethod 310 | def pick(cls, index, dir_path:Path, is_random=True, rule="@random"): 311 | key = dir_path 312 | if key not in InputSource.static_instance_map: 313 | InputSource.static_instance_map[dir_path] = 
InputSource.InnerInputDir(dir_path, rule, is_random) 314 | 315 | return InputSource.static_instance_map[dir_path].pick(index) 316 | 317 | @classmethod 318 | def pick_m(cls, index, dir_path:Path, interval): 319 | key = dir_path 320 | if key not in InputSource.static_instance_map: 321 | InputSource.static_instance_map[dir_path] = InputSource.InnerInputMov(dir_path, interval) 322 | 323 | return InputSource.static_instance_map[dir_path].pick(index) 324 | 325 | 326 | class PromptGenerator(): 327 | def __init__(self, common_info, seq_info, sd): 328 | self.common_prompt =common_prompt = common_info.get("prompt", {}) 329 | self.seq_prompt = seq_prompt = [ seq.get("prompt", {}) for seq in seq_info ] 330 | 331 | self.lora_order = common_prompt.get("lora_order", []) 332 | 333 | common_rule = common_prompt.get("character_lora", None) 334 | seq_rule = [ seq.get("character_lora", None) for seq in seq_prompt ] 335 | logger.info(f"CharacterLora {common_rule=} {seq_rule=}") 336 | self.character_lora = [LoraPicker( LoraType.Character, common_rule, seq_rule )] 337 | 338 | common_rule = common_prompt.get("character_lora2", None) 339 | seq_rule = [ seq.get("character_lora2", None) for seq in seq_prompt ] 340 | logger.info(f"CharacterLora 2 {common_rule=} {seq_rule=}") 341 | self.character_lora.append(LoraPicker( LoraType.Character, common_rule, seq_rule )) 342 | 343 | common_rule = common_prompt.get("style_lora", None) 344 | seq_rule = [ seq.get("style_lora", None) for seq in seq_prompt ] 345 | logger.info(f"StyleLora {common_rule=} {seq_rule=}") 346 | self.style_lora = [LoraPicker( LoraType.Style, common_rule, seq_rule )] 347 | 348 | common_rule = common_prompt.get("style_lora2", None) 349 | seq_rule = [ seq.get("style_lora2", None) for seq in seq_prompt ] 350 | logger.info(f"StyleLora 2 {common_rule=} {seq_rule=}") 351 | self.style_lora.append(LoraPicker( LoraType.Style, common_rule, seq_rule )) 352 | 353 | common_rule = common_prompt.get("pose_lora", None) 354 | seq_rule = [ 
seq.get("pose_lora", None) for seq in seq_prompt ] 355 | logger.info(f"PoseLora {common_rule=} {seq_rule=}") 356 | self.pose_lora = [LoraPicker( LoraType.Pose, common_rule, seq_rule )] 357 | 358 | common_rule = common_prompt.get("pose_lora2", None) 359 | seq_rule = [ seq.get("pose_lora2", None) for seq in seq_prompt ] 360 | logger.info(f"PoseLora 2 {common_rule=} {seq_rule=}") 361 | self.pose_lora.append(LoraPicker( LoraType.Pose, common_rule, seq_rule )) 362 | 363 | common_rule = common_prompt.get("item_lora", None) 364 | seq_rule = [ seq.get("item_lora", None) for seq in seq_prompt ] 365 | logger.info(f"ItemLora {common_rule=} {seq_rule=}") 366 | self.item_lora = [LoraPicker( LoraType.Item, common_rule, seq_rule )] 367 | 368 | common_rule = common_prompt.get("item_lora2", None) 369 | seq_rule = [ seq.get("item_lora2", None) for seq in seq_prompt ] 370 | logger.info(f"ItemLora 2 {common_rule=} {seq_rule=}") 371 | self.item_lora.append(LoraPicker( LoraType.Item, common_rule, seq_rule )) 372 | 373 | common_rule = common_prompt.get("additional_loras", None) 374 | seq_rule = [ seq.get("additional_loras", None) for seq in seq_prompt ] 375 | self.add_lora = LoraAllPicker( common_rule, seq_rule ) 376 | 377 | common_rule = common_prompt.get("preset_tags", None) 378 | seq_rule = [ seq.get("preset_tags", None) for seq in seq_prompt ] 379 | self.preset_tags = PresetTags( common_rule, seq_rule ) 380 | 381 | self.common_prompt_gen = common_prompt_gen = common_info.get("prompt_gen", {}) 382 | self.seq_prompt_gen = seq_prompt_gen = [ seq.get("prompt_gen", {}) for seq in seq_info ] 383 | 384 | common_rule = common_prompt_gen.get("type", None) 385 | seq_rule = [ seq.get("type", None) for seq in seq_prompt_gen ] 386 | 387 | self.random_prompt_picker = RandomPromptPicker(sd, common_rule, seq_rule) 388 | 389 | def _append_random_prompt(self, index, prompt, neg_prompt): 390 | seq_setting = self.seq_prompt_gen[index] 391 | 392 | gen_setting = config_get_default_prompt_gen_setting() 393 
| for s in self.common_prompt_gen: 394 | gen_setting[s] = self.common_prompt_gen[s] 395 | 396 | for s in seq_setting: 397 | gen_setting[s] = seq_setting[s] 398 | is_footer = gen_setting.get("is_footer", True) 399 | 400 | for k in ("type","is_footer"): 401 | if k in gen_setting: 402 | gen_setting.pop(k) 403 | 404 | pro,neg = self.random_prompt_picker.pick(index, gen_setting) 405 | 406 | if pro: 407 | if is_footer: 408 | prompt = ", ".join([prompt, pro]) 409 | neg_prompt = ", ".join([neg_prompt, neg]) 410 | else: 411 | prompt = ", ".join([pro, prompt]) 412 | neg_prompt = ", ".join([neg, neg_prompt]) 413 | 414 | return prompt, neg_prompt 415 | 416 | 417 | def generate(self, index, org_neg, base_part=True, lora_part=True, random_part=True): 418 | seq_prompt = self.seq_prompt[index] 419 | 420 | header = self.common_prompt.get("header", None) 421 | seq_header = seq_prompt.get("header", None) 422 | if seq_header: 423 | header = seq_header 424 | 425 | footer = self.common_prompt.get("footer", None) 426 | seq_footer = seq_prompt.get("footer", None) 427 | if seq_footer: 428 | footer = seq_footer 429 | 430 | prompt = "" 431 | if base_part: 432 | if header: 433 | prompt += header 434 | 435 | if lora_part: 436 | picked_list = [] 437 | for i,lora in enumerate(( self.style_lora[0], self.style_lora[1], self.character_lora[0], self.character_lora[1], self.pose_lora[0], self.pose_lora[1], self.item_lora[0], self.item_lora[1] )): 438 | picked = lora.pick(index) 439 | if picked: 440 | picked_list.append((i, picked, lora.get_type())) 441 | 442 | lora_w_rate = max( 1.0 * 0.9 ** (len(picked_list)-1) , 0.75) 443 | 444 | label = ["style_lora","style_lora2","character_lora","character_lora2","pose_lora","pose_lora2","item_lora","item_lora2"] 445 | 446 | if self.lora_order: 447 | lora_order = [label.index(o) for o in self.lora_order] 448 | custom_order = {c:i for i,c in enumerate(lora_order)} 449 | picked_list.sort(key=lambda c:custom_order.get(c[0], len(custom_order))) 450 | 451 | for i, 
picked, lora_type in picked_list: 452 | logger.info(f"{label[i]} = {picked[0]}") 453 | lora_index = 1 if label[i].endswith("2") else 0 454 | lora_syntax = create_lora_syntax(lora_type, lora_index, picked[0], picked[1] * lora_w_rate) 455 | if picked[2]: 456 | prompt = ", ".join([prompt, lora_syntax, picked[2]]) 457 | else: 458 | prompt = ", ".join([prompt, lora_syntax]) 459 | 460 | add_loras = self.add_lora.pick(index) 461 | for picked in add_loras: 462 | logger.info(f"add : {picked[0]}") 463 | lora_syntax = create_lora_syntax(LoraType.All, 0, picked[0], picked[1] * lora_w_rate) 464 | if picked[2]: 465 | prompt = ", ".join([prompt, lora_syntax, picked[2]]) 466 | else: 467 | prompt = ", ".join([prompt, lora_syntax]) 468 | 469 | 470 | if base_part: 471 | if footer: 472 | prompt = ", ".join([prompt, footer]) 473 | 474 | neg_prompt = org_neg 475 | 476 | if random_part: 477 | prompt, neg_prompt = self._append_random_prompt(index, prompt, neg_prompt) 478 | 479 | if base_part: 480 | prompt, neg_prompt = self.preset_tags.apply(index, prompt, neg_prompt) 481 | 482 | return prompt, neg_prompt 483 | 484 | 485 | class GenerationType(str, Enum): 486 | Txt2Img = "txt2img" 487 | Img2Img = "img2img" 488 | Copy = "copy" 489 | 490 | 491 | 492 | class FocusUtility: 493 | def __init__(self, sd): 494 | self.sd = sd 495 | self.focus_cache = {} 496 | 497 | def get_focus_image_from_path(self, image_path, focus_target, scale): 498 | key = (image_path, focus_target, scale) 499 | if key not in self.focus_cache: 500 | 501 | org_img = Image.open(image_path) 502 | 503 | img = self.sd.create_focus_image(org_img, focus_target, scale) 504 | 505 | self.focus_cache[key] = img 506 | 507 | return self.focus_cache[key] 508 | 509 | def get_focus_image_from_image(self, org_img, key_index, focus_target, scale): 510 | key = (key_index, focus_target, scale) 511 | if key not in self.focus_cache: 512 | 513 | img = self.sd.create_focus_image(org_img, focus_target, scale) 514 | 515 | self.focus_cache[key] = img 
516 | 517 | return self.focus_cache[key] 518 | 519 | def clear_cache(self): 520 | self.focus_cache = {} 521 | 522 | 523 | def create_lora_syntax(lora_type:LoraType, lora_index, stem, weight): 524 | enable_lbw = False 525 | 526 | if lora_type == LoraType.Character: 527 | enable_lbw = config_get_lbw_enable_character(lora_index) 528 | preset = config_get_lbw_preset_character(lora_index) 529 | sss_type = config_get_lbw_start_stop_step_character(lora_index) 530 | sss_val = config_get_lbw_start_stop_step_value_character(lora_index) 531 | elif lora_type == LoraType.Style: 532 | enable_lbw = config_get_lbw_enable_style(lora_index) 533 | preset = config_get_lbw_preset_style(lora_index) 534 | sss_type = config_get_lbw_start_stop_step_style(lora_index) 535 | sss_val = config_get_lbw_start_stop_step_value_style(lora_index) 536 | elif lora_type == LoraType.Pose: 537 | enable_lbw = config_get_lbw_enable_pose(lora_index) 538 | preset = config_get_lbw_preset_pose(lora_index) 539 | sss_type = config_get_lbw_start_stop_step_pose(lora_index) 540 | sss_val = config_get_lbw_start_stop_step_value_pose(lora_index) 541 | elif lora_type == LoraType.Item: 542 | enable_lbw = config_get_lbw_enable_item(lora_index) 543 | preset = config_get_lbw_preset_item(lora_index) 544 | sss_type = config_get_lbw_start_stop_step_item(lora_index) 545 | sss_val = config_get_lbw_start_stop_step_value_item(lora_index) 546 | 547 | if weight != 0: 548 | if enable_lbw: 549 | if sss_type in ("start","stop","step"): 550 | return f"" 551 | else: 552 | return f"" 553 | else: 554 | return f"" 555 | else: 556 | return "" 557 | 558 | 559 | class GenerationSeqSetting(): 560 | def __init__(self, generation_info, sd): 561 | self.gen_info = generation_info 562 | self.common = generation_info.get("common", DEFAULT_GENERATION_COMMON_SETTING) 563 | self.seq = generation_info.get("seq", DEFAULT_GENERATION_SEQ_SETTING) 564 | 565 | common_seed = self.common.get("seed", "@random_per_seq") 566 | seq_seed = [ seq.get("seed", None) 
for seq in self.seq ] 567 | self.seed = SeedPicker(common_seed, seq_seed) 568 | 569 | self.prompt = PromptGenerator(self.common, self.seq, sd) 570 | 571 | self.focus_util = FocusUtility(sd) 572 | 573 | def get_checkpoint_name(self): 574 | ck = config_get_default_checkpoint() 575 | gen_ck = self.common.get("checkpoint", "") 576 | if gen_ck: 577 | ck = gen_ck 578 | return ck 579 | 580 | def _extract_geninfo_from_image(self, index, image, ow_from_png, gen_setting): 581 | param = image.info.get('parameters',{}) 582 | if not param: 583 | raise ValueError(f"invalid gen source") 584 | 585 | param = parse_generation_parameters(param) 586 | 587 | ow_flag = ow_from_png.get("overwrite_steps", True) 588 | if ow_flag: 589 | gen_setting["steps"] = int(param.get("Steps", gen_setting["steps"])) 590 | 591 | ow_flag = ow_from_png.get("overwrite_sampler_name", True) 592 | if ow_flag: 593 | gen_setting["sampler_name"] = param.get("Sampler", gen_setting["sampler_name"]) 594 | 595 | ow_flag = ow_from_png.get("overwrite_scheduler", True) 596 | if ow_flag: 597 | gen_setting["scheduler"] = param.get("Schedule type", gen_setting["scheduler"]) 598 | 599 | ow_flag = ow_from_png.get("overwrite_cfg_scale", True) 600 | if ow_flag: 601 | gen_setting["cfg_scale"] = float(param.get("CFG scale", gen_setting["cfg_scale"])) 602 | 603 | ow_flag = ow_from_png.get("overwrite_seed", True) 604 | if ow_flag: 605 | gen_setting["seed"] = int(param.get("Seed", gen_setting["seed"])) 606 | 607 | ow_flag = ow_from_png.get("overwrite_width", True) 608 | if ow_flag: 609 | gen_setting["width"] = int(param.get("Size-1", gen_setting["width"])) 610 | 611 | ow_flag = ow_from_png.get("overwrite_height", True) 612 | if ow_flag: 613 | gen_setting["height"] = int(param.get("Size-2", gen_setting["height"])) 614 | 615 | ow_flag = ow_from_png.get("overwrite_prompt", True) 616 | if ow_flag: 617 | gen_setting["prompt"] = param.get("Prompt", "") 618 | add_lora = ow_from_png.get("add_lora", False) 619 | add_prompt_gen = 
ow_from_png.get("add_prompt_gen", False) 620 | if add_lora or add_prompt_gen: 621 | gen_setting["prompt"] += "," + self.prompt.generate(index, "", base_part=False, lora_part=add_lora, random_part=add_prompt_gen)[0] 622 | 623 | 624 | ow_flag = ow_from_png.get("overwrite_negative_prompt", True) 625 | if ow_flag: 626 | gen_setting["negative_prompt"] = param.get("Negative prompt", gen_setting["negative_prompt"]) 627 | 628 | 629 | return gen_setting 630 | 631 | 632 | def _overwrite_generation_setting(self, index, gen_setting): 633 | ow_from_png = deepcopy(config_get_default_overwrite_generation_setting()) 634 | ow_ow_from_png = self.common.get("overwrite_generation_setting", {}) 635 | if ow_ow_from_png: 636 | for o in ow_ow_from_png: 637 | ow_from_png[o] = ow_ow_from_png[o] 638 | ow_ow_from_png = self.seq[index].get("overwrite_generation_setting",{}) 639 | if ow_ow_from_png: 640 | for o in ow_ow_from_png: 641 | ow_from_png[o] = ow_ow_from_png[o] 642 | 643 | if ow_from_png: 644 | image_rule = ow_from_png.get("png_info", None) 645 | 646 | if image_rule == None: 647 | return gen_setting 648 | 649 | img = self.get_image_from_image_rule(index, image_rule) 650 | gen_setting = self._extract_geninfo_from_image(index, img, ow_from_png, gen_setting) 651 | 652 | return gen_setting 653 | 654 | 655 | 656 | def _create_setting(self, index): 657 | gen_type = GenerationType( self.seq[index].get("type","txt2img") ) 658 | is_txt2img = (gen_type == GenerationType.Txt2Img) 659 | 660 | gen_setting = deepcopy(config_get_default_generation_setting(is_txt2img)) 661 | ow_gen_setting = self.common.get("generation_setting",{}) 662 | if ow_gen_setting: 663 | for o in ow_gen_setting: 664 | gen_setting[o] = ow_gen_setting[o] 665 | ow_gen_setting = self.seq[index].get("generation_setting",{}) 666 | if ow_gen_setting: 667 | for o in ow_gen_setting: 668 | gen_setting[o] = ow_gen_setting[o] 669 | 670 | gen_setting["seed"] = self.seed.pick(index) 671 | 672 | gen_setting = 
self._overwrite_generation_setting(index, gen_setting) 673 | 674 | if gen_setting.get("prompt", None): 675 | pass 676 | else: 677 | pro,neg = self.prompt.generate(index, gen_setting["negative_prompt"]) 678 | gen_setting["prompt"] = pro 679 | gen_setting["negative_prompt"] = neg 680 | 681 | output_scale = self.seq[index].get("output_scale", None) 682 | if output_scale is not None: 683 | gen_setting["width"] = int(gen_setting["width"] * output_scale) 684 | gen_setting["height"] = int(gen_setting["height"] * output_scale) 685 | 686 | 687 | cn_setting = self.seq[index].get("controlnet", []) 688 | 689 | ad_setting = config_get_adetailer_setting() 690 | common_ad_setting = self.common.get("adetailer",[]) 691 | if common_ad_setting: 692 | ad_setting = common_ad_setting 693 | seq_ad_setting = self.seq[index].get("adetailer",[]) 694 | if seq_ad_setting: 695 | ad_setting = seq_ad_setting 696 | 697 | output_filename = self.seq[index].get("output_filename",None) 698 | 699 | return (gen_type, gen_setting, cn_setting, ad_setting, output_filename) 700 | 701 | def get_image_from_image_rule(self, index, image_rule): 702 | 703 | focus_target = focus_scale = None 704 | 705 | if type(image_rule) == list: 706 | if type(image_rule[0]) == list: 707 | focus_target = image_rule[1] 708 | focus_scale = image_rule[2] 709 | image_rule = image_rule[0] 710 | if len(image_rule) == 1: 711 | image_rule = image_rule[0] 712 | 713 | path_or_img = self.get_path_or_image_from_image_rule(index,image_rule) 714 | 715 | if isinstance(path_or_img, Image.Image): 716 | return path_or_img 717 | else: 718 | 719 | if focus_target: 720 | if isinstance(path_or_img, Path): 721 | return self.focus_util.get_focus_image_from_path(path_or_img, focus_target, focus_scale) 722 | elif type(path_or_img) == int: 723 | org_img = self.result[ path_or_img ] 724 | return self.focus_util.get_focus_image_from_image(org_img, path_or_img, focus_target, focus_scale) 725 | else: 726 | raise ValueError(f"unknown format {path_or_img=}") 
727 | else: 728 | 729 | if isinstance(path_or_img, Path): 730 | return Image.open(path_or_img) 731 | elif type(path_or_img) == int: 732 | return self.result[ path_or_img ] 733 | else: 734 | raise ValueError(f"unknown format {path_or_img=}") 735 | 736 | 737 | def get_path_or_image_from_image_rule(self, index, image_rule): 738 | if type(image_rule) == int: 739 | return image_rule #self.result[ image_rule ] 740 | elif type(image_rule) == str: 741 | image_path = Path(image_rule) 742 | if image_path.is_file(): 743 | return Path(image_rule) 744 | elif image_path.is_dir(): 745 | return InputSource.pick(index, image_path) 746 | else: 747 | raise ValueError(f"unknown rule {image_rule=}") 748 | elif type(image_rule) == list: 749 | image_path = Path(image_rule[0]) 750 | if image_path.is_dir(): 751 | rule = "@random" 752 | if len(image_rule) > 2: 753 | rule = image_rule[2] 754 | is_random = True 755 | if len(image_rule) > 1: 756 | is_random = False if image_rule[1]=="sort" else True 757 | return InputSource.pick(index, image_path, is_random, rule) 758 | elif image_path.is_file() and image_path.suffix in (".mp4", ".MP4"): 759 | if type(image_rule[1]) in (int, float): 760 | return InputSource.pick_m(index, image_path, image_rule[1]) 761 | else: 762 | raise ValueError(f"unknown rule {image_rule[1]=} must be int or float") 763 | else: 764 | raise ValueError(f"unknown rule {image_rule=} {image_path} is invalid path") 765 | else: 766 | raise ValueError(f"unknown rule {image_rule=}") 767 | 768 | def get_input_image(self, index): 769 | input_image_rule = self.seq[index].get("input_image", 0) 770 | return self.get_image_from_image_rule(index, input_image_rule) 771 | 772 | def get_input_image_or_path(self, index): 773 | input_image_rule = self.seq[index].get("input_image", 0) 774 | # ignore focus rule 775 | if type(input_image_rule) == list: 776 | if type(input_image_rule[0]) == list: 777 | input_image_rule = input_image_rule[0] 778 | if len(input_image_rule) == 1: 779 | 
input_image_rule = input_image_rule[0] 780 | 781 | return self.get_path_or_image_from_image_rule(index, input_image_rule) 782 | 783 | def set_result(self, index, image): 784 | self.result[index] = image 785 | 786 | def set_cn_mask_target(self, index, mask): 787 | self.cn_mask_target[index] = mask 788 | 789 | def get_cn_mask_target(self, index): 790 | return self.cn_mask_target[index] 791 | 792 | def __getitem__(self, index): 793 | if index == 0: 794 | self.result = [None for n in range(len(self.seq))] 795 | self.cn_mask_target = [None for n in range(len(self.seq))] 796 | self.focus_util.clear_cache() 797 | 798 | logger.info(f"{index=}") 799 | if 0 <= index < len(self.seq): 800 | try: 801 | return self._create_setting(index) 802 | except Exception as e: 803 | logger.error(traceback.format_exc()) 804 | #exit() 805 | raise e 806 | else: 807 | raise IndexError(f"invalid {index=}") 808 | 809 | 810 | 811 | class SDGen: 812 | def __init__( self, host=SD_HOST, port=SD_PORT, output_dir_path=Path("output") ): 813 | self.api = webuiapi.WebUIApi(host=host, port=port) 814 | self.segif = webuiapi.SegmentAnythingInterface(self.api) 815 | self.output_dir_path = output_dir_path 816 | 817 | def get_checkpoints(self): 818 | return self.api.util_get_model_names() 819 | 820 | def get_latent_upscale_modes(self): 821 | return sorted([s['name'] for s in self.api.get_latent_upscale_modes()]) 822 | 823 | def get_samplers(self): 824 | return self.api.util_get_sampler_names() 825 | 826 | def get_schedulers(self): 827 | return self.api.util_get_scheduler_names() 828 | 829 | def get_sam_models(self): 830 | return self.segif.get_sam_models() 831 | 832 | def get_loras(self): 833 | return self.api.get_loras() 834 | 835 | def get_controlnets(self): 836 | return self.api.controlnet_model_list(), self.api.controlnet_module_list() 837 | 838 | def set_checkpoint(self, name): 839 | self.api.util_set_model(name) 840 | 841 | def prompt_gen(self, prompt_gen_setting): 842 | prompt_gen_setting["batch_count"] 
= 1 843 | prompt_gen_setting["batch_size"] = 1 844 | result = self.api.prompt_gen(**prompt_gen_setting) 845 | return result[0] 846 | 847 | def setup_seq( self, generation_info ): 848 | self.gen_seq = GenerationSeqSetting(generation_info, self) 849 | 850 | def get_seq_length(self): 851 | return len(self.gen_seq.seq) 852 | 853 | def get_mask_of_image(self, img, mask_target): 854 | 855 | if type(mask_target) == int: 856 | return self.gen_seq.get_cn_mask_target(mask_target) 857 | 858 | try: 859 | sam_result = self.segif.sam_predict( 860 | image=img, 861 | sam_model_name = config_get_segment_anything_sam_model_name(), 862 | dino_enabled=True, 863 | dino_text_prompt=mask_target, 864 | dino_model_name= config_get_segment_anything_dino_model_name() 865 | ) 866 | 867 | dilation_result = self.segif.dilate_mask( 868 | image=img, 869 | mask=sam_result.masks[0], # using the first mask from the SAM prediction 870 | dilate_amount=30 871 | ) 872 | except Exception as e: 873 | return None 874 | 875 | return dilation_result.mask.convert('RGB') 876 | 877 | def create_focus_image(self, img, focus_target, scale): 878 | 879 | try: 880 | sam_result = self.segif.sam_predict( 881 | image=img, 882 | sam_model_name = config_get_segment_anything_sam_model_name(), 883 | dino_enabled=True, 884 | dino_text_prompt=focus_target, 885 | dino_model_name= config_get_segment_anything_dino_model_name() 886 | ) 887 | 888 | dilation_result = self.segif.dilate_mask( 889 | image=img, 890 | mask=sam_result.masks[0], # using the first mask from the SAM prediction 891 | dilate_amount=30 892 | ) 893 | 894 | focus_point = get_center_of_mask(dilation_result.mask, 0) 895 | except Exception as e: 896 | focus_point = None 897 | 898 | focus_img = create_focus_image(img, focus_point, scale) 899 | 900 | return focus_img 901 | 902 | def create_controlnet_units(self, index, cn_setting): 903 | units = [] 904 | for cn in cn_setting: 905 | name = cn["type"] 906 | s = get_controlnet_setting(name) 907 | for key in cn.keys(): 
908 | if key not in ("type","image","mask","image_scale","cn_target"): 909 | s[key] = cn[key] 910 | 911 | #unit = webuiapi.ControlNetUnit(**s) 912 | unit = ControlNetUnit2(**s) 913 | img = cn.get("image", None) 914 | if img is not None: 915 | unit.image = self.gen_seq.get_image_from_image_rule(index, img) 916 | img_scale = cn.get("image_scale", None) 917 | if img_scale: 918 | new_size = (int(unit.image.size[0] * img_scale)//8 * 8, int(unit.image.size[1] * img_scale)//8 * 8) 919 | unit.image = unit.image.resize( new_size ) 920 | 921 | def is_valid_target(cn_t): 922 | if type(cn_t) in (int,float): 923 | return True 924 | else: 925 | return cn_t 926 | 927 | cn_target = cn.get("cn_target", None) 928 | if is_valid_target(cn_target) and (img is not None): 929 | unit.effective_region_mask = self.get_mask_of_image(unit.image, cn_target) 930 | self.gen_seq.set_cn_mask_target(index, unit.effective_region_mask) 931 | 932 | img = cn.get("mask", None) 933 | if img is not None: 934 | unit.mask = self.gen_seq.get_image_from_image_rule(index, img) 935 | 936 | units.append(unit) 937 | 938 | return units 939 | 940 | def create_adetailer_units(self, index, ad_setting): 941 | units = [] 942 | for ad in ad_setting: 943 | unit = webuiapi.ADetailer(**ad) 944 | units.append(unit) 945 | return units 946 | 947 | def _run( self, index, gen_type:GenerationType, gen_setting, cn_setting, ad_setting): 948 | 949 | logger.info(f"{index=}") 950 | logger.info(f"{gen_type=}") 951 | logger.info(f"{gen_setting=}") 952 | logger.info(f"{cn_setting=}") 953 | logger.info(f"{ad_setting=}") 954 | 955 | if gen_type == GenerationType.Copy: 956 | copy_src = self.gen_seq.get_input_image_or_path(index) 957 | if isinstance(copy_src, Path): 958 | img = Image.open(copy_src) 959 | self.gen_seq.set_result(index, img) 960 | result = copy_src 961 | else: 962 | raise ValueError(f"copy source(input_image) invalid {copy_src=}") 963 | else: 964 | 965 | cn_units = self.create_controlnet_units(index, cn_setting) 966 | 
ad_units = self.create_adetailer_units(index, ad_setting) 967 | 968 | alwayson_scripts = { 969 | #"Simple wildcards": [] 970 | } 971 | 972 | if gen_type == GenerationType.Txt2Img: 973 | result = self.api.txt2img( **gen_setting, alwayson_scripts=alwayson_scripts, controlnet_units=cn_units, adetailer=ad_units) 974 | else: 975 | images = [ self.gen_seq.get_input_image(index) ] 976 | result = self.api.img2img( images=images, mask_image=None, **gen_setting, alwayson_scripts=alwayson_scripts, controlnet_units=cn_units, adetailer=ad_units) 977 | 978 | self.gen_seq.set_result(index, result.images[0]) 979 | 980 | return result 981 | 982 | def generate( self, n = 1): 983 | 984 | ck = self.gen_seq.get_checkpoint_name() 985 | if ck: 986 | self.set_checkpoint( ck ) 987 | 988 | for i in range(n): 989 | 990 | InputSource.update_seq(i) 991 | 992 | for f,s in enumerate(self.gen_seq): 993 | result = self._run(f, s[0], s[1], s[2], s[3]) 994 | 995 | output_filename = s[4] 996 | if output_filename: 997 | output_path = self.output_dir_path / Path(str(i).zfill(5)) 998 | output_path.mkdir(parents=True, exist_ok=True) 999 | output_path = output_path/Path( output_filename ) 1000 | else: 1001 | output_path = self.output_dir_path / Path( f"{str(i).zfill(5)}_{str(f).zfill(5)}.png") 1002 | 1003 | if isinstance(result, Path): 1004 | shutil.copy(result, output_path) 1005 | else: 1006 | pnginfo = PngImagePlugin.PngInfo() 1007 | pnginfo.add_text("parameters", result.info['infotexts'][0]) 1008 | result.image.save( output_path, pnginfo=pnginfo) 1009 | 1010 | def generate_generator( self, n = 1): 1011 | 1012 | ck = self.gen_seq.get_checkpoint_name() 1013 | if ck: 1014 | self.set_checkpoint( ck ) 1015 | 1016 | for i in range(n): 1017 | 1018 | InputSource.update_seq(i) 1019 | 1020 | for f,s in enumerate(self.gen_seq): 1021 | result = self._run(f, s[0], s[1], s[2], s[3]) 1022 | 1023 | output_filename = s[4] 1024 | if output_filename: 1025 | output_path = self.output_dir_path / Path(str(i).zfill(5)) 1026 
| output_path.mkdir(parents=True, exist_ok=True) 1027 | output_path = output_path/Path( output_filename ) 1028 | else: 1029 | output_path = self.output_dir_path / Path( f"{str(i).zfill(5)}_{str(f).zfill(5)}.png") 1030 | 1031 | if isinstance(result, Path): 1032 | shutil.copy(result, output_path) 1033 | else: 1034 | pnginfo = PngImagePlugin.PngInfo() 1035 | pnginfo.add_text("parameters", result.info['infotexts'][0]) 1036 | result.image.save( output_path, pnginfo=pnginfo) 1037 | 1038 | yield 1039 | 1040 | 1041 | 1042 | ############################################################### 1043 | 1044 | def one_command(char,style,pose,item,header,footer,n): 1045 | 1046 | time_str = get_time_str() 1047 | output_dir = Path("output") / Path(time_str) 1048 | output_dir.mkdir(parents=True) 1049 | 1050 | sd = SDGen(output_dir_path=output_dir) 1051 | 1052 | info = { 1053 | "common" : { 1054 | "prompt":{ 1055 | 1056 | } 1057 | } 1058 | } 1059 | 1060 | if char: 1061 | info["common"]["prompt"]["character_lora"] = char 1062 | if style: 1063 | info["common"]["prompt"]["style_lora"] = style 1064 | if pose: 1065 | info["common"]["prompt"]["pose_lora"] = pose 1066 | if item: 1067 | info["common"]["prompt"]["item_lora"] = item 1068 | if header: 1069 | info["common"]["prompt"]["header"] = header 1070 | if footer: 1071 | info["common"]["prompt"]["footer"] = footer 1072 | 1073 | sd.setup_seq(info) 1074 | 1075 | backup_json_path = Path(output_dir)/Path(time_str + "_gen.json") 1076 | json_text = json.dumps(info, indent=4, ensure_ascii=False) 1077 | backup_json_path.write_text(json_text, encoding="utf-8") 1078 | 1079 | sd.generate(n) 1080 | 1081 | 1082 | def generate_command(json_path:Path, n): 1083 | 1084 | time_str = get_time_str() 1085 | output_dir = Path("output") / Path(time_str) 1086 | output_dir.mkdir(parents=True) 1087 | 1088 | sd = SDGen(output_dir_path=output_dir) 1089 | 1090 | info = {} 1091 | if json_path.is_file(): 1092 | with open(json_path, "r", encoding="utf-8") as f: 1093 | info = 
json.load(f) 1094 | else: 1095 | raise ValueError(f"invalid path {json_path=}") 1096 | 1097 | sd.setup_seq(info) 1098 | 1099 | backup_json_path = Path(output_dir)/Path(time_str + "_gen.json") 1100 | json_text = json.dumps(info, indent=4, ensure_ascii=False) 1101 | backup_json_path.write_text(json_text, encoding="utf-8") 1102 | 1103 | sd.generate(n) 1104 | 1105 | 1106 | def show_checkpoint_command(): 1107 | sd = SDGen() 1108 | for i,c in enumerate(sd.get_checkpoints()): 1109 | logger.info(f"[{i}] {c}") 1110 | 1111 | def set_default_checkpoint_command(checkpoint_number): 1112 | sd = SDGen() 1113 | cks = sd.get_checkpoints() 1114 | 1115 | if checkpoint_number >= len(cks): 1116 | raise ValueError(f"invalid {checkpoint_number=}") 1117 | 1118 | config_set_default_checkpoint(cks[checkpoint_number]) 1119 | 1120 | logger.info(f"Set to {config_get_default_checkpoint()}") 1121 | 1122 | def show_controlnet_command(): 1123 | sd = SDGen() 1124 | 1125 | models, modules = sd.get_controlnets() 1126 | logger.info(f"== Controlnet Models ==") 1127 | for i,c in enumerate(models): 1128 | logger.info(f"[{i}] {c}") 1129 | logger.info(f"== Controlnet Modules ==") 1130 | for i,c in enumerate(modules): 1131 | logger.info(f"[{i}] {c}") 1132 | 1133 | 1134 | ############################################################### 1135 | _controlnet_model_cache = [] 1136 | _controlnet_module_cache = [] 1137 | 1138 | def get_controlnet_list(): 1139 | global _controlnet_model_cache,_controlnet_module_cache 1140 | if not _controlnet_model_cache: 1141 | sd = SDGen() 1142 | _controlnet_model_cache,_controlnet_module_cache = sd.get_controlnets() 1143 | _controlnet_model_cache = ["none"] + _controlnet_model_cache 1144 | return _controlnet_model_cache,_controlnet_module_cache 1145 | 1146 | 1147 | ############################################################### 1148 | 1149 | _checkpoint_list_cache=[] 1150 | def get_checkpoint_list(): 1151 | global _checkpoint_list_cache 1152 | if not _checkpoint_list_cache: 1153 | 
sd = SDGen() 1154 | _checkpoint_list_cache = sd.get_checkpoints() 1155 | return _checkpoint_list_cache 1156 | 1157 | 1158 | _latent_upscale_mode_list_cache=[] 1159 | def get_latent_upscale_mode_list(): 1160 | global _latent_upscale_mode_list_cache 1161 | if not _latent_upscale_mode_list_cache: 1162 | sd = SDGen() 1163 | _latent_upscale_mode_list_cache = sd.get_latent_upscale_modes() 1164 | return _latent_upscale_mode_list_cache 1165 | 1166 | 1167 | _sampler_list_cache=[] 1168 | def get_sampler_list(): 1169 | global _sampler_list_cache 1170 | if not _sampler_list_cache: 1171 | sd = SDGen() 1172 | _sampler_list_cache = sd.get_samplers() 1173 | return _sampler_list_cache 1174 | 1175 | 1176 | _scheduler_list_cache=[] 1177 | def get_scheduler_list(): 1178 | global _scheduler_list_cache 1179 | if not _scheduler_list_cache: 1180 | sd = SDGen() 1181 | _scheduler_list_cache = sd.get_schedulers() 1182 | return _scheduler_list_cache 1183 | 1184 | 1185 | _sam_model_list_cache=[] 1186 | def get_sam_model_list(): 1187 | global _sam_model_list_cache 1188 | if not _sam_model_list_cache: 1189 | sd = SDGen() 1190 | _sam_model_list_cache = sd.get_sam_models() 1191 | return _sam_model_list_cache 1192 | 1193 | 1194 | 1195 | 1196 | ############################################################### 1197 | generate_cancel_flag = False 1198 | 1199 | def sd_task(output_dir, info, n, on_progress, on_complete): 1200 | 1201 | sd = SDGen(output_dir_path=output_dir) 1202 | 1203 | sd.setup_seq(info) 1204 | 1205 | progress = 0 1206 | total = sd.get_seq_length() * n 1207 | 1208 | g = sd.generate_generator(n) 1209 | 1210 | for progress in range(1, total+1): 1211 | try: 1212 | next(g) 1213 | except Exception as e: 1214 | logger.error(traceback.format_exc()) 1215 | on_complete("Failed") 1216 | return 1217 | 1218 | on_progress(progress/total) 1219 | if generate_cancel_flag: 1220 | on_complete("Cancel") 1221 | return 1222 | 1223 | on_complete("Success") 1224 | 1225 | 1226 | def 
async_generate(json_path:Path, n, on_progress, on_complete): 1227 | import threading 1228 | 1229 | global generate_cancel_flag 1230 | 1231 | time_str = get_time_str() 1232 | output_dir = Path("output") / Path(time_str) 1233 | output_dir.mkdir(parents=True) 1234 | 1235 | info = {} 1236 | if json_path.is_file(): 1237 | with open(json_path, "r", encoding="utf-8") as f: 1238 | info = json.load(f) 1239 | else: 1240 | raise ValueError(f"invalid path {json_path=}") 1241 | 1242 | backup_json_path = Path(output_dir)/Path(time_str + "_gen.json") 1243 | json_text = json.dumps(info, indent=4, ensure_ascii=False) 1244 | backup_json_path.write_text(json_text, encoding="utf-8") 1245 | 1246 | generate_cancel_flag = False 1247 | 1248 | thr = threading.Thread(target=sd_task, args=(output_dir, info, n, on_progress, on_complete)) 1249 | thr.start() 1250 | 1251 | return output_dir.absolute() 1252 | 1253 | def cancel_generate(): 1254 | global generate_cancel_flag 1255 | generate_cancel_flag = True 1256 | 1257 | --------------------------------------------------------------------------------