├── image └── directory_for_saved_images.txt ├── stable_cascade └── directory_for_checkpoints.txt ├── app.bat ├── src ├── screenshot.png ├── wildcard_scene.py ├── app_retnet.py └── image_save_file.py ├── install.bat ├── update.bat ├── .gitignore ├── gitignore.txt ├── env └── .env ├── docker-compose.yml ├── requirements.txt ├── Dockerfile ├── LICENSE.md ├── README.md └── app.py /image/directory_for_saved_images.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stable_cascade/directory_for_checkpoints.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app.bat: -------------------------------------------------------------------------------- 1 | call .\venv\Scripts\activate 2 | py app.py 3 | pause -------------------------------------------------------------------------------- /src/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/another-ai/stable_cascade_easy/HEAD/src/screenshot.png -------------------------------------------------------------------------------- /install.bat: -------------------------------------------------------------------------------- 1 | py -m venv venv 2 | call .\venv\Scripts\activate 3 | pip install -r requirements.txt 4 | pause 5 | -------------------------------------------------------------------------------- /update.bat: -------------------------------------------------------------------------------- 1 | git stash 2 | git pull 3 | call .\venv\Scripts\activate 4 | pip install -r requirements.txt 5 | pause -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | **/__pycache__ 3 | *.ckpt 4 | *.safetensors 5 | *.pt 6 | *.pth 7 | *.env 8 | /tmp 9 | /env 10 | /venv 11 | /python_emb 12 | -------------------------------------------------------------------------------- /gitignore.txt: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | **/__pycache__ 3 | *.ckpt 4 | *.safetensors 5 | *.pt 6 | *.pth 7 | *.env 8 | /tmp 9 | /env 10 | /venv 11 | /python_emb 12 | -------------------------------------------------------------------------------- /env/.env: -------------------------------------------------------------------------------- 1 | # default values: 2 | checkpoint_basename=./stable_cascade/ 3 | negative_prompt= 4 | sampler=DDPMWuerstchenScheduler 5 | batch_size=1 6 | random_seed=true 7 | input_seed=1234 8 | width=1280 9 | height=1536 10 | guidance_scale=4 11 | num_inference_steps=20 12 | num_inference_steps_decode=12 13 | contrast=1 14 | dynamic_prompt=0 15 | # for Magic Prompt: 16 | banned_words=greyscale,monochrome 17 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.9' 2 | 3 | services: 4 | app: 5 | build: . 6 | ports: 7 | - "${WEBUI_PORT:-7860}:7860" 8 | stop_signal: SIGKILL 9 | tty: true 10 | volumes: 11 | - ./data/.cache:/root/.cache 12 | - ./data/image:/stable-cascade/image 13 | deploy: 14 | resources: 15 | reservations: 16 | devices: 17 | - driver: nvidia 18 | device_ids: ['0'] 19 | capabilities: [compute, utility] 20 | environment: 21 | - CLI_ARGS= 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/cu118 2 | torch==2.2.1+cu118 3 | torchvision==0.17.1+cu118 4 | --find-links https://download.pytorch.org/whl/torch_stable.html 5 | accelerate==0.27.2 6 | gradio==4.23.0 7 | python-dotenv==1.0.1 8 | transformers==4.38.0 9 | numpy==1.26.4 10 | kornia==0.7.1 11 | insightface==0.7.3 12 | llama-cpp-python==0.2.26 13 | opencv-python==4.9.0.80 14 | tqdm==4.66.3 15 | matplotlib==3.8.2 16 | webdataset==0.2.86 17 | wandb==0.16.3 18 | munch==4.0.0 19 | onnxruntime==1.17.0 20 | einops==0.7.0 21 | onnx2torch==1.5.13 22 | peft==0.8.2 23 | diffusers==0.27.2 24 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 25 | torchtools @ git+https://github.com/pabloppp/pytorch-tools 26 | -------------------------------------------------------------------------------- /src/wildcard_scene.py: -------------------------------------------------------------------------------- 1 | import random 2 | def wildcard_scene_def(directory="./wildcards",file_name="example"): 3 | if file_name != "": 4 | if directory[-1] != "/": 5 | directory = directory + "/" 6 | directory_path = f"{directory}{file_name}.txt" 7 | try: 8 | with open(directory_path, "r") as file: 9 | lines = file.readlines() 10 | if lines: 11 | random_line = random.choice(lines) 12 | return random_line.strip() 13 | else: 14 | print("The file is empty.") 15 | except FileNotFoundError: 16 | print(f"The file {file_name}.txt was not found.") 17 | else: 18 | print("The file name cannot be empty.") 19 | return "" -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive PIP_PREFER_BINARY=1 4 | 5 | RUN apt-get update && apt-get install -y git gcc g++ && apt-get clean 6 | 7 | ENV ROOT=/stable-cascade 8 | RUN --mount=type=cache,target=/root/.cache/pip \ 9 | git clone https://github.com/another-ai/stable_cascade_easy.git ${ROOT} 10 | 11 | WORKDIR ${ROOT} 12 | 13 | RUN pip install -r requirements.txt 14 | RUN pip install git+https://github.com/kashif/diffusers.git@a3dc21385b7386beb3dab3a9845962ede6765887 15 | 16 | RUN sed -i "s/demo.launch(inbrowser=True)/demo.launch(server_name='0.0.0.0',share=False)/" app.py 17 | RUN sed -i "17i torch.backends.cuda.enable_flash_sdp(False)" app.py 18 | RUN sed -i "17i torch.backends.cuda.enable_mem_efficient_sdp(False)" app.py 19 | 20 | ENV NVIDIA_VISIBLE_DEVICES=all PYTHONPATH="${PYTHONPATH}:${PWD}" CLI_ARGS="" 21 | EXPOSE 7860 22 | CMD python app.py ${CLI_ARGS} 23 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stability AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/app_retnet.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer 2 | import re 3 | 4 | def contains_banned_word(word, banned_words,prompt_chara): 5 | banned_words_text = banned_words 6 | if prompt_chara: 7 | banned_words_text = banned_words_text + ["chibi","eyes","hair","heterochromia","mask","mole"] 8 | for banned_word_text in banned_words_text: 9 | if re.search(banned_word_text, word): 10 | return True 11 | return False 12 | 13 | def replace_word(text, banned_words=[],prompt_chara=False): 14 | 15 | cleaned_words = [word for word in text.split(",") if not contains_banned_word(word, banned_words, prompt_chara)] 16 | 17 | cleaned_text = ",".join(cleaned_words) 18 | 19 | return cleaned_text 20 | 21 | def main_def(prompt_input, max_tokens=256, DEVICE="cpu", banned_words=[], prompt_chara=False): 22 | # if prompt_input == "": 23 | # prompt_input = "1woman" 24 | 25 | MODEL_NAME = "isek-ai/SDPrompt-RetNet-v2-beta" 26 | 27 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) 28 | model = AutoModelForCausalLM.from_pretrained( 29 | MODEL_NAME, 30 | trust_remote_code=True, 31 | ).to(DEVICE) 32 | 33 | streamer = TextStreamer(tokenizer) 34 | 35 | prompt_input = ""+prompt_input 36 | 37 | inputs = tokenizer(prompt_input, return_tensors="pt", add_special_tokens=False)["input_ids"] 38 | 39 | print(f"Token={max_tokens}") 40 | token_ = model.generate( 41 | inputs, 42 | max_new_tokens=max_tokens, 43 | do_sample=True, 44 | top_p=0.9, 45 | top_k=20, 46 | temperature=0.9, 47 | streamer=streamer, 48 | ) 49 | generated_text = tokenizer.decode(token_[0], skip_special_tokens=True).strip() 50 | if len(banned_words) > 0: 51 | generated_text = replace_word(generated_text, banned_words, prompt_chara) 52 | generated_text = generated_text.replace("", "").replace("", "").replace(",,,",",").replace(",,", ",").replace(", ", ",").replace(" ", "_") 53 | print(generated_text) 54 | return generated_text 55 | 56 | if __name__ == "__main__": 57 | main_def("") -------------------------------------------------------------------------------- /src/image_save_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime as date_time 3 | import re 4 | from PIL import Image 5 | from PIL.PngImagePlugin import PngInfo 6 | 7 | def count_file(directory_path_temp): 8 | unique_id_temp = 0 9 | existing_files = len([f for f in os.listdir(directory_path_temp) if (f.endswith(".png") or f.endswith(".jpg")) and (os.path.isfile(os.path.join(directory_path_temp, f)))]) 10 | unique_id_temp = existing_files + 1 11 | return unique_id_temp 12 | 13 | def count_folders(directory_path_temp, new_folder): 14 | unique_id_temp = 0 15 | existing_folders = [ 16 | int(d.split('_')[0]) for d in os.listdir(directory_path_temp) if (os.path.isdir(os.path.join(directory_path_temp, d)) and re.search(r'^\d+', d)) 17 | ] 18 | if existing_folders: 19 | unique_id_temp = max(existing_folders) 20 | if new_folder: 21 | unique_id_temp = unique_id_temp + 1 22 | else: 23 | unique_id_temp = 1 24 | return str(unique_id_temp) 25 | 26 | def add_metadata_file(file_path, txt_file_data_file): 27 | targetImage = Image.open(file_path) 28 | metadata = PngInfo() 29 | metadata.add_text("parameters", txt_file_data_file) 30 | targetImage.save(file_path, pnginfo=metadata) 31 | 32 | 33 | def save_file(image_file, txt_file_data_file): 34 | file_path = "" 35 | if image_file != "": 36 | current_datetime = date_time.now() 37 | current_date = current_datetime.strftime(f"%Y_%m_%d") 38 | current_time = current_datetime.strftime(f"%H_%M_%S") 39 | if not os.path.exists("./image"): 40 | os.makedirs("./image") 41 | if not os.path.exists("./image/" + current_date): 42 | os.makedirs("./image/" + current_date) 43 | directory_path = f"./image/{current_date}" 44 | print(f"Directory:{directory_path}") 45 | if not os.path.exists(directory_path): 46 | os.makedirs(directory_path) 47 | unique_id = count_file(directory_path) 48 | file_name = f"{unique_id}_{current_time}.png" 49 | file_path = f"{directory_path}/{file_name}" 50 | image_file.save(file_path) 51 | add_metadata_file(file_path, txt_file_data_file) 52 | return file_path 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stable_cascade_easy 2 | Text to Img with Stable Cascade(on gradio interface), required less vram than original example on official Hugginface(https://huggingface.co/stabilityai/stable-cascade): 3 | - 44 seconds for a 1280x1536 image with a nvidia RTX3060 with 12 GB VRAM 4 | - 31 seconds with LCM Scheduler for a 1280x1536 image(6 steps on prior module) with a nvidia RTX3060 with 12 GB VRAM 5 | - 26 seconds with LCM Scheduler for a 1024*768 image(6 steps on prior module) with a nvidia RTX3060 with 12 GB VRAM 6 | 7 | ![](src/screenshot.png) 8 | 9 | ## Why is stable_cascade_easy faster than hugginface example of stability ai? 10 | Answer: because stable cascade is composed of two models, many gb each, stability ai example loads both models simultaneously into the gpu vram. While this application loads the first one(prior), creates the image(latents), cleans the vram and sends the image(latents) to the second model(decoder) and then returns the final image and cleans the vram completely... for PC with less than 16 gb of vram without this "trick" all 2 models would not fit in the vram and then you would have to use the system ram with a huge drop in performance(the time drops from 10 minutes to 44 seconds, 1280x1536 with nvidia rtx 3060 12 gb vram) 11 | 12 | # Versions: 13 | - v1.0: First version 14 | - v1.1: Diffusers fix 15 | - v1.2: Scheduler drop down menu(with LCM and DPM++ 2M Karras compatibility) 16 | - v1.2.5: Added "Batch Size", number of images per prompt at the same time 17 | - v1.3: Added "Magic Prompt"(prompt auto-creation) 18 | - v1.4: Now you can select your favourite Stable Cascade Checkpoint(in ```.\stable_cascade``` directory) 19 | 20 | # Issues: 21 | With a different checkpoint than default Stable Cascade, same seed = different images 22 | 23 | # Installation on Windows: 24 | 1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/), checking "Add Python to PATH". 25 | 2. Install [git](https://git-scm.com/download/win). 26 | 3. On terminal: 27 | ```bash 28 | git clone https://github.com/shiroppo/stable_cascade_easy 29 | cd stable_cascade_easy 30 | py -m venv venv 31 | .\venv\Scripts\activate 32 | pip install -r requirements.txt 33 | ``` 34 | # Installation on Linux: 35 | #### (Thanks to @blahouel) 36 | 37 | 1- Clone the repository: 38 | git clone ```https://github.com/another-ai/stable_cascade_easy.git``` 39 | 40 | 2- open terminal in the cloned directory: stable_cascade_easy 41 | type the following prompt: 42 | ```python3 -m venv env``` 43 | 44 | 3- to activate the virtual environment type: 45 | ```source env/bin/activate``` 46 | 47 | 4- your terminal will change to (env) for the new commands. Type the following 48 | ```pip install -r requirements.txt``` 49 | 50 | 5- the git pull will now work without errors. when install is finished type the following 51 | ```python3 app.py``` 52 | 53 | it will take a while to download the models and launch the Web UI in your default browser. 54 | 55 | to launch again you can write a new file in your text editor and save in in the stable_cascade_easy directory. save the file as start.sh 56 | here's the text you need to write in the ```start.sh``` file, you need to change "user" to your own user name: 57 | 58 | ```#!/bin/bash``` 59 | 60 | Specify the paths to your virtual environment and start.py script 61 | ```venv_path="/home/user/stable_cascade_easy"``` 62 | 63 | Open a new Gnome terminal window 64 | ```bash 65 | gnome-terminal --working-directory=$venv_path -- bash -ic 66 | "source env/bin/activate; 67 | python3 app.py; 68 | exec bash" 69 | ``` 70 | 71 | ## Avoid warnings about deprecated packages "Peft"(Linux) - not necessary from v1.3 because peft is already in requirements.txt: 72 | #### (Thanks to @blahouel) 73 | 1- in the installation directory open terminal: 74 | type in the following command: 75 | ```source env/bin/activate``` 76 | 77 | 2- type in the next commend: ```pip install peft``` 78 | 79 | 3- after installation: exit the terminal and restart stable cascade. 80 | 81 | # Run on Windows: 82 | ### Method 1 83 | Double click on ```app.bat``` on stable_cascade_easy directory 84 | ### Method 2 85 | On terminal: 86 | ```bash 87 | .\venv\Scripts\activate 88 | py app.py 89 | ``` 90 | # Update: 91 | 1. ```git pull```(if error: ```git stash``` and after ```git pull```) 92 | 2. ```.\venv\Scripts\activate``` 93 | 3. ```pip install -r requirements.txt``` 94 | 95 | # Magic Prompt 96 | From v1.3 you can choose how many tokens you want and llama will create the prompt for you(based on prompt that you have previously inserted, 0 = magic prompt deactivate, 32/64 = usually good results) 97 | - Thanks to https://huggingface.co/isek-ai/SDPrompt-RetNet-v2-beta, the first time you will use the magic prompt, the system automatically downloads the necessary llama model. 98 | 99 | # Scheduler 100 | You can choose between DDPMWuerstchenScheduler(default) and LCM. Scheduler only for prior model, decode model only works with default scheduler. 101 | 102 | ## Scheduler - DDPMWuerstchenScheduler(default) 103 | Default scheduler, guidance scale recommended: 4, prior steps recommended: 20+ 104 | 105 | ## Scheduler - LCM 106 | LCM can use 6+ steps on prior models so the image creation is even faster, guidance scale recommended: 4, prior steps recommended: from 6 to 18 107 | 108 | # Output 109 | Created images will be saved in the "image" folder 110 | 111 | ## Contrast: 112 | Possibility to change the final image contrast, value from 0.5 to 1.5, no change with value 1(best results from 0.95 to 1.05) 113 | 114 | ## Dimensions(Width and Length) 115 | Multiples of 128 for Stable Cascade, but the app will resize the image for you, so you can choose any size you want 116 | 117 | ## Guidance Scale and Guidance Scale Decode 118 | Choice the value that you want for Guidance Scale(Prior), for the Guidance Scale Decode now is hidden because value different than 0 causes errors and consequent not creation of the image 119 | 120 | ## Code(without gradio): 121 | ```bash 122 | import torch 123 | from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline 124 | import gc 125 | 126 | device = "cuda" 127 | num_images_per_prompt = 1 128 | 129 | prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to(device) 130 | prior.safety_checker = None 131 | prior.requires_safety_checker = False 132 | 133 | prompt = "a cat" 134 | negative_prompt = "" 135 | 136 | prior_output = prior( 137 | prompt=prompt, 138 | width=1280, 139 | height=1536, 140 | negative_prompt=negative_prompt, 141 | guidance_scale=4.0, 142 | num_images_per_prompt=num_images_per_prompt, 143 | num_inference_steps=20 144 | ) 145 | 146 | del prior 147 | gc.collect() 148 | torch.cuda.empty_cache() 149 | 150 | decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to(device) 151 | decoder.safety_checker = None 152 | decoder.requires_safety_checker = False 153 | 154 | decoder_output = decoder( 155 | image_embeddings=prior_output.image_embeddings.half(), 156 | prompt=prompt, 157 | negative_prompt=negative_prompt, 158 | guidance_scale=0.0, 159 | output_type="pil", 160 | num_inference_steps=12 161 | ).images[0].save("image.png") 162 | 163 | # del decoder 164 | # gc.collect() 165 | # torch.cuda.empty_cache() 166 | ``` 167 | ## Support: 168 | - ko-fi: (https://ko-fi.com/shiroppo) 169 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime as date_time 3 | import sys 4 | import gc 5 | path = os.path.abspath("src") 6 | sys.path.append(path) 7 | import time 8 | import torch 9 | from diffusers import ( 10 | StableCascadeDecoderPipeline, 11 | StableCascadePriorPipeline, 12 | StableCascadeUNet, 13 | ) 14 | from diffusers import LCMScheduler # LCM Scheduler 15 | import gradio as gr 16 | import random 17 | from PIL import ImageEnhance 18 | import image_save_file 19 | from dotenv import load_dotenv 20 | import platform 21 | 22 | if torch.cuda.is_available(): 23 | device = "cuda" 24 | elif torch.backends.mps.is_available(): 25 | device = "mps" 26 | else: 27 | device = "cpu" 28 | 29 | torch_dtype = torch.bfloat16 30 | 31 | def constrast_image(image_file, factor): 32 | im_constrast = ImageEnhance.Contrast(image_file).enhance(factor) 33 | return im_constrast 34 | 35 | def generate_image(checkpoint_basename,checkpoint_prior,checkpoint_decoder,prompt_input,dynamic_prompt,negative_prompt,sampler_choice,num_images_per_prompt,random_seed,input_seed,width,height,guidance_scale,num_inference_steps,num_inference_steps_decoder,contrast): 36 | 37 | def remove_last_comma(sentence): 38 | if len(sentence) > 0 and sentence[-1] == ',': 39 | sentence_without_comma = sentence[:-1] 40 | return sentence_without_comma 41 | else: 42 | return sentence 43 | 44 | def remove_duplicates(words): 45 | words_list = words.split(",") 46 | unique_words = [] 47 | for word in words_list: 48 | if word not in unique_words: 49 | unique_words.append(word) 50 | unique_string = ",".join(unique_words) 51 | return unique_string 52 | 53 | if dynamic_prompt > 0: 54 | if prompt_input != "": 55 | if prompt_input[-1] != ",": 56 | prompt_input = prompt_input + "," 57 | banned_words = os.getenv("banned_words", "").split(",") 58 | import app_retnet 59 | prompt = app_retnet.main_def(prompt_input=prompt_input, max_tokens=dynamic_prompt, DEVICE="cpu", banned_words=banned_words, prompt_chara=False) 60 | prompt = remove_duplicates(prompt.lower()) 61 | prompt = remove_last_comma(prompt) 62 | else: 63 | prompt = prompt_input 64 | 65 | if prompt == "": 66 | prompt = "a cat with the sign: prompt not found, write in black" 67 | negative_prompt = negative_prompt 68 | 69 | if random_seed: 70 | input_seed = random.randint(0, 9999999999) 71 | else: 72 | input_seed = int(input_seed) 73 | 74 | if float(guidance_scale).is_integer(): 75 | guidance_scale = int(guidance_scale) # for txt_file_data correct format 76 | 77 | generator = torch.Generator(device=device).manual_seed(input_seed) 78 | 79 | print(f"Prompt: {prompt}") 80 | 81 | checkpoint_prior_name = "" 82 | if len(checkpoint_prior) < 1: 83 | prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch_dtype).to(device) 84 | checkpoint_prior_name = "stable_cascade" 85 | else: 86 | prior_unet = StableCascadeUNet.from_single_file(checkpoint_basename + checkpoint_prior,torch_dtype=torch_dtype) 87 | prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", prior=prior_unet, torch_dtype=torch_dtype).to(device) 88 | checkpoint_prior_name = os.path.splitext(checkpoint_prior)[0] 89 | 90 | prior.safety_checker = None 91 | prior.requires_safety_checker = False 92 | 93 | resize_pixel_w = width % 128 94 | resize_pixel_h = height % 128 95 | if resize_pixel_w > 0: 96 | width = width - resize_pixel_w 97 | if resize_pixel_h > 0: 98 | height = height - resize_pixel_h 99 | 100 | start_time = time.time() 101 | 102 | match sampler_choice: 103 | case "LCM": 104 | sampler = "LCM" 105 | prior.scheduler = LCMScheduler.from_config(prior.scheduler.config) 106 | case _: 107 | sampler = "DDPMWuerstchenScheduler" # default 108 | 109 | prior_output = prior( 110 | prompt=prompt, 111 | negative_prompt=negative_prompt, 112 | generator=generator, 113 | width=width, 114 | height=height, 115 | guidance_scale=guidance_scale, 116 | num_inference_steps=num_inference_steps, 117 | num_images_per_prompt=num_images_per_prompt 118 | ) 119 | 120 | if len(checkpoint_prior) > 0: 121 | del prior_unet 122 | del prior 123 | gc.collect() 124 | if device=="cuda": 125 | torch.cuda.empty_cache() 126 | 127 | # checkpoint_decoder_name = "" 128 | if len(checkpoint_decoder) < 1: 129 | decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch_dtype).to(device) 130 | # checkpoint_decoder_name = "stable_cascade_decoder" 131 | else: 132 | decoder_unet = StableCascadeUNet.from_single_file(checkpoint_basename + checkpoint_decoder,torch_dtype=torch_dtype) 133 | decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", decoder=decoder_unet, torch_dtype=torch_dtype).to(device) 134 | # checkpoint_decoder_name = os.path.splitext(checkpoint_decoder)[0] 135 | decoder.safety_checker = None 136 | decoder.requires_safety_checker = False 137 | 138 | images = decoder(image_embeddings=prior_output.image_embeddings, 139 | prompt=prompt, 140 | negative_prompt=negative_prompt, 141 | generator=generator, 142 | guidance_scale=0, 143 | num_inference_steps=num_inference_steps_decoder, 144 | output_type="pil" 145 | ).images 146 | 147 | end_time = time.time() 148 | 149 | duration = end_time - start_time 150 | 151 | print(f"Time: {duration} seconds.") 152 | 153 | if resize_pixel_w > 0: 154 | width = width + resize_pixel_w 155 | if resize_pixel_h > 0: 156 | height = height + resize_pixel_h 157 | 158 | for image in images: 159 | if resize_pixel_w > 0 or resize_pixel_h > 0: 160 | image = image.resize((width, height)) 161 | 162 | if contrast != 1: 163 | image = constrast_image(image, contrast) 164 | 165 | txt_file_data=prompt+"\n"+"Negative prompt: "+negative_prompt+"\n"+"Steps: "+str(num_inference_steps)+", Sampler: "+sampler+", CFG scale: "+str(guidance_scale)+", Seed: "+str(input_seed)+", Size: "+str(width)+"x"+str(height)+", Model: "+checkpoint_prior_name 166 | 167 | file_path = image_save_file.save_file(image, txt_file_data) 168 | 169 | if len(checkpoint_decoder) > 0: 170 | del decoder_unet 171 | del decoder 172 | gc.collect() 173 | if device=="cuda": 174 | torch.cuda.empty_cache() 175 | 176 | return_txt_file_data = f"{txt_file_data}\nTime: {duration} seconds." 177 | 178 | yield images, return_txt_file_data 179 | 180 | def stop_gen(checkpoint_prior, checkpoint_decoder): 181 | try: 182 | if len(checkpoint_prior) > 0: 183 | del prior_unet 184 | if len(checkpoint_decoder) > 0: 185 | del decoder_unet 186 | del prior 187 | del decoder 188 | gc.collect() 189 | torch.cuda.empty_cache() 190 | finally: 191 | os.execv(sys.executable, [sys.executable, __file__, "restart"]) 192 | 193 | 194 | def open_dir(dir="image"): 195 | current_datetime = date_time.now() 196 | current_date = current_datetime.strftime(f"%Y_%m_%d") 197 | folder = os.getcwd() + "/" + dir + "/" + current_date 198 | if not os.path.exists(folder): 199 | folder = os.getcwd() + "/" + dir 200 | 201 | if os.path.exists(folder): 202 | operating_system = platform.system() 203 | if operating_system == 'Windows': 204 | os.startfile(folder) 205 | elif operating_system == 'Darwin': 206 | os.system('open "{}"'.format(folder)) 207 | elif operating_system == 'Linux': 208 | os.system('xdg-open "{}"'.format(folder)) 209 | 210 | if __name__ == "__main__": 211 | 212 | load_dotenv("./env/.env") 213 | 214 | default_checkpoint_basename=os.getenv("checkpoint_basename","./stable_cascade/") 215 | default_negative_prompt = os.getenv("negative_prompt", "") 216 | default_sampler = os.getenv("sampler", "DDPMWuerstchenScheduler") 217 | default_batch_size = int(os.getenv("batch_size", "1")) 218 | default_random_seed = os.getenv("random_seed", "true").lower() == "true" 219 | default_input_seed = int(os.getenv("input_seed", "1234")) 220 | default_width = int(os.getenv("width", "768")) 221 | default_height = int(os.getenv("height", "1024")) 222 | default_guidance_scale = float(os.getenv("guidance_scale", "4")) 223 | default_num_inference_steps = int(os.getenv("num_inference_steps", "20")) 224 | default_num_inference_steps_decoder = int(os.getenv("num_inference_steps_decode", "12")) 225 | default_contrast = float(os.getenv("contrast", "1")) 226 | sampler_choice_list= ["DDPMWuerstchenScheduler", "LCM"] 227 | dynamic_prompt=int(os.getenv("dynamic_prompt", "0")) 228 | 229 | generator_image = generate_image 230 | 231 | inbrowser_ = True 232 | if len(sys.argv) > 1: 233 | if sys.argv[1] == "restart": 234 | inbrowser_ = False 235 | 236 | 237 | if default_checkpoint_basename[-1] != "/": 238 | default_checkpoint_basename = default_checkpoint_basename + "/" 239 | 240 | checkpoints_prior_list = [os.path.basename(file) for file in os.listdir(default_checkpoint_basename) if file.endswith(".safetensors")] 241 | checkpoints_decoder_list = [os.path.basename(file) for file in os.listdir(default_checkpoint_basename) if file.endswith(".safetensors")] 242 | 243 | with gr.Blocks() as demo: 244 | with gr.Row(): 245 | with gr.Column(): 246 | title="stable_cascade_easy" 247 | checkpoint_basename = gr.Textbox(value=default_checkpoint_basename, label="Checkpoint Path", visible=False) 248 | checkpoint_prior=gr.Dropdown(value=None, choices=checkpoints_prior_list, allow_custom_value=True, filterable=True, label="Checkpoint(Prior, Stage C), empty for Stable Cascade Default Prior") 249 | checkpoint_decoder=gr.Dropdown(value=None, choices=checkpoints_decoder_list, allow_custom_value=True, filterable=True, label="Checkpoint(Decoder, Stage B), empty for Stable Cascade Default Decoder") 250 | prompt_input=gr.Textbox(value="", lines=4, label="Prompt") 251 | dynamic_prompt = gr.Number(value=dynamic_prompt, label="Magic Prompt(max tokens, 0=off)",step=32,minimum=0,maximum=1024) 252 | negative_prompt=gr.Textbox(value=default_negative_prompt, lines=4, label="Negative Prompt") 253 | sampler_choice=gr.Dropdown(value=default_sampler, choices=sampler_choice_list, label="Scheduler") 254 | num_images_per_prompt=gr.Number(value=default_batch_size, label="Batch Size",step=1,minimum=1,maximum=16) 255 | random_seed=gr.Checkbox(value=default_random_seed, label="Random Seed") 256 | input_seed=gr.Number(value=default_input_seed, label="Input Seed",step=1,minimum=0, maximum=9999999999) 257 | width=gr.Number(value=default_width, label="Width",step=100) 258 | height=gr.Number(value=default_height, label="Height",step=100) 259 | guidance_scale=gr.Number(value=default_guidance_scale, label="Guidance Scale",step=1) 260 | with gr.Row(): 261 | num_inference_steps=gr.Number(value=default_num_inference_steps, label="Steps Prior",step=1) 262 | num_inference_steps_decoder=gr.Number(value=default_num_inference_steps_decoder, label="Steps Decoder",step=1) 263 | contrast=gr.Slider(value=default_contrast, label="Contrast(Default Value = 1)",step=0.05,minimum=0.5,maximum=1.5) 264 | with gr.Row(): 265 | btn_stop_gen = gr.Button(value="Stop") 266 | btn_generate = gr.Button(value="Generate") 267 | with gr.Column(): 268 | output_images=gr.Gallery(allow_preview=True, preview=True, label="Generated Images", show_label=True) 269 | btn_open_dir = gr.Button(value="Open Image Directory") 270 | output_text=gr.Textbox(label="Metadata") 271 | btn_generate.click(generator_image, inputs=[checkpoint_basename, checkpoint_prior,checkpoint_decoder,prompt_input,dynamic_prompt,negative_prompt,sampler_choice,num_images_per_prompt,random_seed,input_seed,width,height,guidance_scale,num_inference_steps,num_inference_steps_decoder,contrast],outputs=[output_images,output_text]) 272 | btn_open_dir.click(open_dir, inputs=[], outputs=[]) 273 | btn_stop_gen.click(stop_gen, inputs=[checkpoint_prior,checkpoint_decoder], outputs=[]) 274 | demo.launch(inbrowser=inbrowser_) 275 | --------------------------------------------------------------------------------