├── .gitignore
├── ColabUI
│   ├── ArtistIndex.py
│   ├── BaseUI.py
│   ├── CanvasHolder.py
│   ├── ColabWrapper.py
│   ├── DiffusionImg2ImgUI.py
│   ├── DiffusionPipelineUI.py
│   ├── FlairView.py
│   ├── HugginfaceModelIndex.py
│   ├── ImageViewer.py
│   ├── Img2ImgRefinerUI.py
│   ├── InpaintRefinerUI.py
│   ├── KandinskyUI.py
│   ├── LoraApplyer.py
│   ├── LoraChoice.py
│   ├── LoraDownloader.py
│   ├── SamplerChoice.py
│   ├── SettingsTabs.py
│   ├── TextualInversionChoice.py
│   └── VaeChoice.py
├── LICENSE
├── README.md
├── artists.txt
├── docs
│   └── ui-example.jpg
├── lpw_stable_diffusion.py
├── model_index.json
├── requirements.txt
├── scripts
│   ├── convert_original_stable_diffusion_to_diffusers.py
│   └── convert_vae_pt_to_diffusers.py
├── sd_diffusers_colab_ui.ipynb
├── sd_diffusers_img2img_ui.ipynb
├── sd_kandinsky_colab_ui.ipynb
└── utils
    ├── downloader.py
    ├── empty_output.py
    ├── event.py
    ├── image_utils.py
    ├── kohya_lora_loader.py
    ├── markdown.py
    ├── preview.py
    └── ui.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

--------------------------------------------------------------------------------
/ColabUI/ArtistIndex.py:
--------------------------------------------------------------------------------
from IPython.display import display, Javascript
from ipywidgets import Dropdown, HTML, HBox, Layout, Text, Button, Accordion, Output
import json

class ArtistIndex:
    __clipboard: str

    def __init__(self, ui, index_path = "artist_index.json"):
        with open(index_path) as f:
            self.data = json.load(f)

        self.ui = ui
        self.example_link = HTML()

        self.__setup_artist_dropdown()
        self.__setup_buttons()
        self.__setup_foldout()

    def render(self):
        """Display ui"""
        self.__set_link_from_item(self.artist_dropdown.value)
        display(self.foldout, self.output)

    def __setup_artist_dropdown(self):
        self.artist_dropdown = Dropdown(
            options=[x for x in self.data],
            description="Artist:",
        )
        self.artist_dropdown.description_tooltip = "Choose artist flair tag to add"
        def dropdown_eventhandler(change):
            self.__set_link_from_item(change.new)
        self.artist_dropdown.observe(dropdown_eventhandler, names='value')

    def __setup_buttons(self):
        self.add_button = Button(description="Add", layout=Layout(width='80px'))
        def on_add_clicked(b):
            self.ui.positive_prompt.value += f", by {self.artist_dropdown.value}"
        self.add_button.on_click(on_add_clicked)

        self.copy_button = Button(description='Copy', tooltip="Copy to clipboard", layout=Layout(width='80px'))
        self.output = Output()
        def copy_event_handler(b):
            with self.output:
                text = self.__clipboard
                display(Javascript(f"navigator.clipboard.writeText('{text}');"))
        self.copy_button.on_click(copy_event_handler)

    def __setup_foldout(self):
        hbox = HBox([self.artist_dropdown, self.add_button, self.copy_button, self.example_link])
        self.foldout = Accordion(children=[hbox])
        self.foldout.set_title(0, "Flair tags")
        self.foldout.selected_index = None

    def __set_link_from_item(self, item):
        self.__clipboard = item
        key = item.lower().replace(" ", "-")
        self.example_link.value = f"Example: {item}"

    def _ipython_display_(self):
        self.render()
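The index file itself is not shown in this dump; since the dropdown only iterates the top-level entries of the loaded JSON, a flat list of names should satisfy it. A hypothetical sketch of producing a compatible file (names are illustrative):

# Hypothetical sketch: write a minimal artist_index.json that ArtistIndex can read.
import json

artists = ["Alphonse Mucha", "Ivan Aivazovsky", "John Singer Sargent"]  # illustrative
with open("artist_index.json", "w") as f:
    json.dump(artists, f, indent=2)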
--------------------------------------------------------------------------------
/ColabUI/BaseUI.py:
--------------------------------------------------------------------------------
from IPython.display import display
from ipywidgets import Textarea, IntText, IntSlider, FloatSlider, HBox, Layout, Label, Button
from PIL.PngImagePlugin import PngInfo
import ipywidgets as widgets
import io

class BaseUI:
    _metadata: str

    def __init__(self):
        self.positive_prompt = Textarea(placeholder='Positive prompt...', layout=Layout(width="50%"))
        self.negative_prompt = Textarea(placeholder='Negative prompt...', layout=Layout(width="50%"))
        self.width_field = IntText(value=512, layout=Layout(width='200px'), description="width: ")
        self.height_field = IntText(value=768, layout=Layout(width='200px'), description="height: ")
        self.steps_field = IntSlider(value=25, min=1, max=100, description="Steps: ")
        self.cfg_field = FloatSlider(value=7, min=1, max=20, step=0.5, description="CFG: ")
        self.seed_field = IntText(value=-1, description="seed: ")
        self.batch_field = IntText(value=1, layout=Layout(width='150px'), description="Batch size ")
        self.count_field = IntText(value=1, layout=Layout(width='150px'), description="Sample count ")
        self.clip_skip = None
        self.cfg_rescale = 0.0
        self.prompt_preprocessors = []

    def render(self):
        """Render UI widgets."""
        l1 = Label("Positive prompt:", layout=Layout(width="50%"))
        l2 = Label("Negative prompt:", layout=Layout(width="50%"))
        prompts = HBox([self.positive_prompt, self.negative_prompt])

        def swap_dims(b):
            self.width_field.value, self.height_field.value = self.height_field.value, self.width_field.value

        swap_btn = Button(description="⇆", layout=Layout(width='40px'))
        swap_btn.on_click(swap_dims)

        size_box = HBox([self.width_field, self.height_field, swap_btn])
        batch_box = HBox([self.batch_field, self.count_field])

        display(HBox([l1, l2]), prompts, size_box, self.steps_field, self.cfg_field, self.seed_field, batch_box)

    def get_size(self):
        """Get current (width, height) values"""
        return (self.width_field.value, self.height_field.value)

    def save_image_with_metadata(self, image, path, additional_data = ""):
        meta = PngInfo()
        meta.add_text("Data", self._metadata + additional_data)
        image.save(path, pnginfo=meta)

    def get_positive_prompt(self):
        return self.preprocess_prompt(self.positive_prompt.value)

    def get_negative_prompt(self):
        return self.preprocess_prompt(self.negative_prompt.value)

    def preprocess_prompt(self, prompt: str):
        for preprocessor in self.prompt_preprocessors:
            prompt = preprocessor.process(prompt)
        return prompt

    @property
    def metadata(self):
        return self._metadata

    @property
    def sample_count(self):
        return self.count_field.value

    def get_metadata_string(self):
        return f"\nPrompt: {self.positive_prompt.value}\nNegative: {self.negative_prompt.value}\nCFG: {self.cfg_field.value} Steps: {self.steps_field.value} "

    def get_dict_to_cache(self):
        return {
            "prompt": self.positive_prompt.value,
            "n_prompt": self.negative_prompt.value,
            "CFG": self.cfg_field.value,
            "steps": self.steps_field.value,
            "h": self.height_field.value,
            "w": self.width_field.value,
            "batch": self.batch_field.value,
        }

    def load_cache(self, cache):
        self.positive_prompt.value = cache["prompt"]
        self.negative_prompt.value = cache["n_prompt"]
        self.cfg_field.value = cache["CFG"]
        self.steps_field.value = cache["steps"]
        self.height_field.value = cache["h"]
        self.width_field.value = cache["w"]
        self.batch_field.value = cache["batch"]

    def _ipython_display_(self):
        self.render()
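save_image_with_metadata embeds the generation settings as a PNG text chunk under the "Data" key, so they can be recovered later with PIL (the path below is illustrative):

# Read back the settings that save_image_with_metadata embedded.
from PIL import Image

image = Image.open("outputs/txt2img/00000.png")  # illustrative path
print(image.text.get("Data"))  # prompt, negative prompt, CFG, steps, seed, ...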
--------------------------------------------------------------------------------
/ColabUI/CanvasHolder.py:
--------------------------------------------------------------------------------
from ipycanvas import MultiCanvas, hold_canvas
from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox, VBox, Button, Checkbox
from ipywidgets import Image
from IPython.display import display
import PIL

class DrawContext:
    def __init__(self):
        self.drawing = False
        self.position = None
        self.shape = []

class CanvasHolder:
    def __init__(self, width=512, height=512, line_color="#749cb8"):
        self.multicanvas = MultiCanvas(2, width=width, height=height)
        self.canvas = self.multicanvas[1]  # foreground
        self.background = self.multicanvas[0]
        self.context = DrawContext()
        self.save_file = "temp.png"
        self.scale_factor = 1

        self.canvas.on_mouse_down(self.on_mouse_down)
        self.canvas.on_mouse_move(self.on_mouse_move)
        self.canvas.on_mouse_up(self.on_mouse_up)

        self.picker = ColorPicker(description="Color:", value=line_color)
        link((self.picker, "value"), (self.canvas, "stroke_style"))
        link((self.picker, "value"), (self.canvas, "fill_style"))
        self.canvas.stroke_style = line_color

        self.line_width_slider = IntSlider(5, min=1, max=20, description="Line Width:")
        link((self.line_width_slider, "value"), (self.canvas, "line_width"))

        self.fill_check = Checkbox(value=False, description="Fill shape")

        self.clear_btn = Button(description="Clear")
        def clear_handler(b):
            self.set_dirty()
            with hold_canvas(): self.canvas.clear()
        self.clear_btn.on_click(clear_handler)

        self.clear_back_btn = Button(description="Clear Background")
        def clear_background_handler(b):
            with hold_canvas(): self.background.clear()
        self.clear_back_btn.on_click(clear_background_handler)

        self.save_btn = Button(description="Save")
        def save_handler(b):
            with hold_canvas(): self.save_mask()
        self.save_btn.on_click(save_handler)

        def save_to_file(*args, **kwargs):
            self.canvas.to_file(self.save_file)
        self.canvas.observe(save_to_file, "image_data")

    def on_mouse_down(self, x, y):
        self.set_dirty()
        self.context.drawing = True
        self.context.position = (x, y)
        self.context.shape = [self.context.position]

    def on_mouse_move(self, x, y):
        if not self.context.drawing: return
        self.context.shape.append((x, y))
        with hold_canvas():
            self.canvas.stroke_line(self.context.position[0], self.context.position[1], x, y)
            if self.canvas.line_width > 3: self.canvas.fill_circle(x, y, 0.5 * self.canvas.line_width)
        self.context.position = (x, y)

    def on_mouse_up(self, x, y):
        self.context.drawing = False
        with hold_canvas():
            self.canvas.stroke_line(self.context.position[0], self.context.position[1], x, y)
            if self.fill_check.value:
                self.canvas.fill_polygon(self.context.shape)
        self.context.shape = []

    def add_background(self, filepath, fit_image = False, fit_canvas = True, scale_factor = 1):
        with hold_canvas():
            self.background.clear()
            self.scale_factor = scale_factor
            img = Image.from_file(filepath)  # Image widget
            img_size = PIL.Image.open(filepath).size
            if scale_factor != 1:
                img_size = (scale_factor * img_size[0], scale_factor * img_size[1])

            if fit_canvas:
                self.set_size(img_size[0], img_size[1])

            if fit_image:
                self.background.draw_image(img, 0, 0, width=self.multicanvas.width, height=self.multicanvas.height)
            else:
                self.background.draw_image(img, 0, 0, width=img_size[0], height=img_size[1])

    def save_mask(self, filepath = "temp.png"):
        self.save_file = filepath
        self.canvas.sync_image_data = True  # optimization hack to not sync image all the time
        self.save_btn.button_style = "success"

    def set_dirty(self):
        self.canvas.sync_image_data = False
        self.save_btn.button_style = ""

    def set_size(self, width, height):
        self.set_dirty()
        self.multicanvas.width = width
        self.multicanvas.height = height
        # This stuff gets reset for some reason
        self.canvas.stroke_style = self.picker.value
        self.canvas.fill_style = self.picker.value
        self.canvas.line_width = self.line_width_slider.value

    @property
    def render_element(self):
        return HBox([self.multicanvas,
                     VBox([self.picker, self.line_width_slider, self.fill_check,
                           HBox([self.clear_btn, self.save_btn]),
                           self.clear_back_btn
                          ])
                    ])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()
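A minimal sketch of how CanvasHolder is meant to be driven from a notebook cell (the filename is illustrative). Saving is asynchronous: pressing Save enables image-data syncing, and the observer writes the drawing to save_file once the canvas syncs.

# Minimal usage sketch (notebook cell); "input.png" is illustrative.
holder = CanvasHolder(width=512, height=512)
holder.add_background("input.png", scale_factor=0.5)
holder.render()
# Draw with the mouse, then press "Save"; the image_data observer
# writes the drawing to holder.save_file once the canvas syncs.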
--------------------------------------------------------------------------------
/ColabUI/ColabWrapper.py:
--------------------------------------------------------------------------------
import os
import torch
from IPython.display import display
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
from diffusers import AutoencoderKL
from .FlairView import FlairView
from .HugginfaceModelIndex import HugginfaceModelIndex
from .LoraDownloader import LoraDownloader
from .LoraApplyer import LoraApplyer
from .SettingsTabs import SettingsTabs
from .Img2ImgRefinerUI import Img2ImgRefinerUI
from ..utils.preview import get_image_previews, try_create_dir
from ..utils.image_utils import save_image_with_metadata

class ColabWrapper:
    def __init__(self, output_dir: str = "outputs/txt2img", img2img_dir: str = "outputs/img2img",
                 favourite_dir: str = "outputs/favourite"):
        self.output_dir = output_dir
        self.img2img_dir = img2img_dir
        self.favourite_dir = favourite_dir
        self.output_index = 0
        self.cache = None
        self.lora_cache = dict()
        self.pipe = None
        self.custom_pipeline = None

        try_create_dir(output_dir)
        try_create_dir(img2img_dir)
        try_create_dir(favourite_dir)

    def render_model_index(self, filepath: str):
        self.model_index = HugginfaceModelIndex(filepath)
        self.model_index.render()

    def load_model(self, pipeline_interface, custom_pipeline: str = "lpw_stable_diffusion"):
        model_id, from_ckpt = self.model_index.get_model_id()
        loader_func = pipeline_interface.from_single_file if from_ckpt else pipeline_interface.from_pretrained

        self.custom_pipeline = custom_pipeline

        #TODO we can try catch here, and load file itself if diffusers doesn't want to load it for us
        self.pipe = loader_func(model_id,
                                custom_pipeline=self.custom_pipeline,
                                torch_dtype=torch.float16).to("cuda")
        self.pipe.safety_checker = None
        self.pipe.enable_xformers_memory_efficient_attention()

    def load_vae(self, id_or_path: str, subfolder: str):
        vae = AutoencoderKL.from_pretrained(id_or_path, subfolder=subfolder, torch_dtype=torch.float16).to("cuda")
        self.pipe.vae = vae

    def load_textual_inversion(self, path: str, filename: str):
        self.pipe.load_textual_inversion(path, weight_name=filename)

    def load_textual_inversions(self, root_folder: str):
        for path, subdirs, files in os.walk(root_folder):
            for name in files:
                try:
                    ext = os.path.splitext(name)[1]
                    if ext != ".pt" and ext != ".safetensors": continue
                    self.pipe.load_textual_inversion(path, weight_name=name)
                    print(path, name)
                except: pass

    def render_lora_loader(self, output_dir="Lora"):
        # LoraDownloader expects the wrapper itself (it accesses .pipe internally).
        self.lora_downloader = LoraDownloader(self, output_dir=output_dir, cache=self.lora_cache)
        self.lora_downloader.render()

    def render_lora_ui(self):
        # LoraApplyer expects the wrapper and the downloader's cache, not the pipe.
        self.lora_ui = LoraApplyer(self, cache=self.lora_downloader.cache)
        self.lora_ui.render()

    def render_generation_ui(self, ui):
        self.ui = ui
        if (self.cache is not None): self.ui.load_cache(self.cache)

        self.settings = SettingsTabs(self)

        self.flair = FlairView(ui)

        self.settings.render()
        self.ui.render()
        self.flair.render()

    def restore_cached_ui(self):
        if (self.cache is not None): self.ui.load_cache(self.cache)

    def generate(self):
        self.cache = self.ui.get_dict_to_cache()
        paths = []

        for sample in range(self.ui.sample_count):
            results = self.ui.generate(self.pipe)
            for i, image in enumerate(results.images):
                path = os.path.join(self.output_dir, f"{self.output_index:05}.png")
                save_image_with_metadata(image, path, self.ui, f"Batch: {i}\n")
                self.output_index += 1
                paths.append(path)

        if self.settings.display_previews:
            display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))

        return paths

    #StableDiffusionXLImg2ImgPipeline
    def render_refiner_ui(self, pipeline_interface):
        components = self.pipe.components
        self.img2img_pipe = pipeline_interface(**components)

        self.refiner_ui = Img2ImgRefinerUI(self.ui)
        self.refiner_ui.render()

    def refiner_generate(self, output_dir: str = None):
        results = self.refiner_ui.generate(self.img2img_pipe)
        paths = []

        if output_dir is None: output_dir = self.img2img_dir
        for i, image in enumerate(results.images):
            path = os.path.join(output_dir, f"{self.output_index:05}.png")
            save_image_with_metadata(image, path, self.refiner_ui, f"Batch: {i}\n")
            self.output_index += 1
            paths.append(path)

        if self.settings.display_previews:
            display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))
        return paths

    def render_inpaint_ui(self, pipeline_interface, ui):
        components = self.pipe.components
        self.inpaint_pipe = pipeline_interface(**components)

        self.inpaint_ui = ui
        self.inpaint_ui.render()

    def inpaint_generate(self, output_dir: str = None):
        results = self.inpaint_ui.generate(self.inpaint_pipe)
        paths = []

        if output_dir is None: output_dir = self.img2img_dir
        for i, image in enumerate(results.images):
            path = os.path.join(output_dir, f"{self.output_index:05}.png")
            save_image_with_metadata(image, path, self.inpaint_ui, f"Batch: {i}\n")
            self.output_index += 1
            paths.append(path)

        if self.settings.display_previews:
            display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))
        return paths
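The notebooks themselves are not part of this section, but a plausible minimal flow through ColabWrapper, matching the methods above, looks like this (import paths assume the repo root is on sys.path; adjust to how the repo is cloned):

# Hedged sketch of a notebook flow using the wrapper above.
from diffusers import StableDiffusionPipeline

colab = ColabWrapper()
colab.render_model_index("model_index.json")       # cell 1: pick a model
colab.load_model(StableDiffusionPipeline)          # cell 2: load it (lpw_stable_diffusion)
colab.render_generation_ui(DiffusionPipelineUI())  # cell 3: settings + prompt UI
paths = colab.generate()                           # cell 4: run and save PNGs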
--------------------------------------------------------------------------------
/ColabUI/DiffusionImg2ImgUI.py:
--------------------------------------------------------------------------------
import torch
from IPython.display import display
from ipywidgets import FloatSlider, Dropdown, HBox, Text, Button, FileUpload
from PIL import Image
import ipywidgets as widgets
import io
from .BaseUI import BaseUI

class DiffusionImg2ImgUI(BaseUI):
    _current_file: str
    _current_image: Image

    def __init__(self):
        super().__init__()
        self.__generator = torch.Generator(device="cuda")
        self.strength_field = FloatSlider(value=0.75, min=0, max=1, step=0.05, description="Strength: ")
        self.path_field = Text(description="Url:", placeholder="Image path...")
        self.upload_button = Button(description='Choose')
        self.image_preview = widgets.Image()

        def uploader_eventhandler(b):
            self._current_file = self.path_field.value
            temp = widgets.Image().from_file(self._current_file)
            self.image_preview.value = temp.value
            self._current_image = Image.open(self._current_file)
            self.set_size(self._current_image)
        self.upload_button.on_click(uploader_eventhandler)

    def render(self):
        super().render()
        display(self.strength_field, HBox([self.path_field, self.upload_button]), self.image_preview)

    def set_size(self, image, scaling_factor=1):
        (self.width_field.value, self.height_field.value) = tuple(scaling_factor * x for x in image.size)

    def generate(self, pipe, init_image=None, generator=None):
        """Generate images given DiffusionPipeline, and settings set in UI."""
        if self.seed_field.value >= 0:
            seed = self.seed_field.value
        else:
            seed = self.__generator.seed()

        if init_image is None: init_image = self._current_image

        g = torch.cuda.manual_seed(seed)
        self._metadata = self.get_metadata_string() + f"Seed: {seed} "

        init_image = init_image.convert('RGB')
        init_image = init_image.resize((self.width_field.value, self.height_field.value), resample=Image.BILINEAR)

        results = pipe(image=init_image,
                       prompt=self.positive_prompt.value,
                       negative_prompt=self.negative_prompt.value,
                       num_inference_steps=self.steps_field.value,
                       num_images_per_prompt=self.batch_field.value,
                       guidance_scale=self.cfg_field.value,
                       strength=self.strength_field.value,
                       generator=g)
        return results

    def get_dict_to_cache(self):
        cache = super().get_dict_to_cache()
        cache["image_path"] = self.path_field.value
        return cache

    def load_cache(self, cache):
        super().load_cache(cache)
        self.path_field.value = cache["image_path"]
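Worth noting for the Strength slider: diffusers img2img pipelines only run roughly the last strength × num_inference_steps denoising steps, so the slider trades fidelity to the input image against how much gets redrawn:

# Approximate number of denoising steps diffusers img2img actually executes:
steps, strength = 25, 0.75
print(int(steps * strength))  # -> 18; strength=1.0 redraws from (almost) pure noise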
--------------------------------------------------------------------------------
/ColabUI/DiffusionPipelineUI.py:
--------------------------------------------------------------------------------
import torch
from .BaseUI import BaseUI

class DiffusionPipelineUI(BaseUI):
    def __init__(self):
        super().__init__()
        self.__generator = torch.Generator(device="cuda")

    def generate(self, pipe, generator = None):
        """Generate images given DiffusionPipeline, and settings set in UI."""
        if self.seed_field.value >= 0:
            seed = self.seed_field.value
        else:
            seed = self.__generator.seed()

        g = torch.cuda.manual_seed(seed)
        self._metadata = self.get_metadata_string() + f"Seed: {seed} "

        results = pipe(prompt=self.get_positive_prompt(),
                       negative_prompt=self.negative_prompt.value,
                       num_inference_steps=self.steps_field.value,
                       num_images_per_prompt=self.batch_field.value,
                       guidance_scale=self.cfg_field.value,
                       guidance_rescale=self.cfg_rescale,
                       generator=g, clip_skip=self.clip_skip,
                       height=self.height_field.value, width=self.width_field.value)
        return results
--------------------------------------------------------------------------------
/ColabUI/FlairView.py:
--------------------------------------------------------------------------------
import os
import random
import typing
from IPython.display import display, Javascript
from ipywidgets import Output, Text, Layout, Button, HBox, VBox, Dropdown, Accordion, Checkbox
from ..utils.empty_output import EmptyOutput
from ..utils.downloader import download_ckpt
from ..utils.ui import SwitchButton

class FlairItem:
    __data: list[str]

    def __init__(self, ui, out: Output = None, index: int = -1):
        self.ui = ui
        if out is None: out = EmptyOutput()
        self.out = out
        self.clipboard_output = Output()
        self.index = index

        self.url_text = Text(placeholder="Path or url to load", layout=Layout(width="40%"))
        if index >= 0: self.url_text.description = f"{index + 1}:"
        self.url_text.description_tooltip = "Path to a file with a list of items, or a url"

        self.load_button = Button(description="Load", layout=Layout(width='50px'))
        def load_btn_click(b):
            self.out.clear_output()
            with self.out:
                self.load_data(self.url_text.value)
                if self.__data is not None and len(self.__data) > 0:
                    self.switch_to_dropdown()
        self.load_button.on_click(load_btn_click)

        self.load_view = HBox([self.url_text, self.load_button])

        self.flair_dropdown = Dropdown(layout=Layout(width="40%"))
        if index >= 0: self.flair_dropdown.description = f"{index + 1}:"
        self.add_button = Button(description="Add", layout=Layout(width='80px'))
        def add_btn_click(b):
            self.ui.positive_prompt.value += f", by {self.flair_dropdown.value}"
        self.add_button.on_click(add_btn_click)

        self.copy_button = Button(description="Copy", tooltip="Copy to clipboard", layout=Layout(width='80px'))
        def copy_event_handler(b):
            with self.clipboard_output:
                text = self.flair_dropdown.value
                display(Javascript(f"navigator.clipboard.writeText('{text}');"))
        self.copy_button.on_click(copy_event_handler)

        self.preprocess_btn = SwitchButton("Preprocess", "Preprocess", Layout(width='100px'),
                                           self.preprocessor_on, self.preprocessor_off)
        self.preprocess_btn.button.tooltip = f"Automatically substitute {{{index + 1}}} for an element of this list in prompt"
        self.subsitute_index = 0
        self.subsitute_to_random = Checkbox(value=False, description="Random", indent=False)

        self.dropdown_view = HBox([self.flair_dropdown, self.add_button, self.copy_button,
                                   self.preprocess_btn.render_element, self.subsitute_to_random])
        self.dropdown_view.layout.visibility = "hidden"
        self.dropdown_view.layout.height = "0px"

    def switch_to_dropdown(self):
        self.load_view.layout.visibility = "hidden"
        self.dropdown_view.layout.visibility = "visible"
        self.dropdown_view.layout.height = self.load_view.layout.height
        self.load_view.layout.height = "0px"
        self.flair_dropdown.options = self.__data
        self.flair_dropdown.description_tooltip = self.url_text.value

    @property
    def render_element(self):
        return VBox([self.load_view, self.dropdown_view, self.clipboard_output])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()

    def preprocessor_on(self):
        self.ui.prompt_preprocessors.append(self)

    def preprocessor_off(self):
        self.ui.prompt_preprocessors.remove(self)

    def process(self, prompt: str):
        #TODO guard index
        item = self.__data[self.subsitute_index]
        self.step_subsitute_index()
        return prompt.replace(f"{{{self.index + 1}}}", item)

    def set_subsitute_index(self, index: int = 0):
        self.subsitute_index = index

    def step_subsitute_index(self):
        if self.subsitute_to_random.value:
            self.subsitute_index = random.randint(0, len(self.__data) - 1)
        else:
            self.subsitute_index += 1
            if self.subsitute_index >= len(self.__data): self.subsitute_index = 0

    def load_data(self, path: str):
        if os.path.isfile(path):
            self.__data = self.read_file(path)
        else:
            f = download_ckpt(path)
            self.__data = self.read_file(f)

    @staticmethod
    def read_file(path: str):
        with open(path) as f: data = f.readlines()
        data = [x.strip('\n') for x in data]
        return data

class FlairView:
    __items: list[FlairItem]

    def __init__(self, ui):
        self.ui = ui
        self.vbox = VBox([])

        self.output = Output(layout={'border': '1px solid black'})
        self.clear_output_button = Button(description="Clear Log")
        def on_clear_clicked(b):
            self.output.clear_output()
        self.clear_output_button.on_click(on_clear_clicked)
        right_align_layout = Layout(display='flex', flex_flow='column', align_items='flex-end', width='99%')

        self.add_button = Button(description="Add", layout=Layout(width='80px'), button_style='info')
        def add_btn_click(b):
            f = FlairItem(self.ui, self.output, len(self.vbox.children))
            self.vbox.children = tuple(list(self.vbox.children) + [f.render_element])
        self.add_button.on_click(add_btn_click)

        self.foldout = Accordion(children=[VBox([self.vbox, self.add_button, self.output,
                                                 HBox(children=[self.clear_output_button], layout=right_align_layout)])])
        self.foldout.set_title(0, "Flair tags")
        self.foldout.selected_index = None

    @property
    def render_element(self):
        return self.foldout

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()
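The preprocessor protocol is simply an object with a process(prompt) method registered in BaseUI.prompt_preprocessors; FlairItem uses it to substitute {n} placeholders in the prompt. A worked example of the substitution it performs (the list is illustrative):

# Worked example of FlairItem's placeholder substitution:
prompt = "portrait of {1}, highly detailed"
data = ["a knight", "a wizard"]  # list loaded into flair item number 1
print(prompt.replace("{1}", data[0]))  # -> "portrait of a knight, highly detailed"
# step_subsitute_index then advances (or randomizes) the index for the next sample.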
--------------------------------------------------------------------------------
/ColabUI/HugginfaceModelIndex.py:
--------------------------------------------------------------------------------
from IPython.display import display
from ipywidgets import Dropdown, HTML, HBox, Layout, Text, Checkbox
import json
import requests
import werkzeug
import os
from tqdm.notebook import tqdm
from ..utils.downloader import download_ckpt

class HugginfaceModelIndex:
    def __init__(self, filepath = "model_index.json"):
        with open(filepath) as f:
            self.data = json.load(f)

        self.model_link = HTML()
        self.from_ckpt = Checkbox(value=False, description="A1111 format")

        self.__setup_url_field()
        self.__setup_model_dropdown()

    def render(self):
        """Display ui"""
        self.__set_link_from_dict(self.model_dropdown.value)
        display(self.model_dropdown, HBox([self.url_text, self.from_ckpt]), self.model_link)

    def get_model_id(self):
        """Return model_id/url/local path of the model, and whether it should be loaded with from_ckpt"""
        if self.url_text.value != "":
            return self.__handle_url_value(self.url_text.value)
        return self.data[self.model_dropdown.value]["id"], False

    def __setup_model_dropdown(self):
        self.model_dropdown = Dropdown(
            options=[x for x in self.data],
            description=self.highlight("Model:"),
        )
        def dropdown_eventhandler(change):
            self.__set_link_from_dict(change.new)
        self.model_dropdown.observe(dropdown_eventhandler, names='value')
        self.model_dropdown.description_tooltip = "Choose model from model index"

    def __setup_url_field(self):
        self.url_text = Text(description="Url:", placeholder='Optional url or path...', layout=Layout(width="50%"))
        def url_eventhandler(change):
            self.__set_link_from_url(change.new)
        self.url_text.observe(url_eventhandler, names='value')
        self.url_text.description_tooltip = "Model_id, url, or local path for any other model not in index"

    def __set_link_from_dict(self, key):
        self.from_ckpt.value = False
        self.from_ckpt.disabled = True
        self.url_text.value = ""
        self.model_link.value = f"Model info: {key}"
        try:
            self.model_link.value += f"<br>Trigger prompt: {self.data[key]['trigger']}"
        except: pass

    def __set_link_from_url(self, new_url):
        if new_url == "":
            self.__set_link_from_dict(self.model_dropdown.value)
            self.url_text.description = "Url:"
            self.model_dropdown.description = self.highlight("Model:")
            self.from_ckpt.value = False
            self.from_ckpt.disabled = True
        else:
            self.model_link.value = f"Model info: <a href='{new_url}' target='_blank'>link</a>"
            self.url_text.description = self.highlight("Url:")
            self.model_dropdown.description = "Model:"
            if new_url.startswith("https://civitai.com/api/download/models/") or new_url.endswith(".safetensors"):
                self.from_ckpt.value = True
            elif self.is_huggingface_model_id(new_url):
                self.from_ckpt.value = False
            elif os.path.exists(new_url):
                self.from_ckpt.value = os.path.isfile(new_url)
            self.from_ckpt.disabled = False

    def __handle_url_value(self, url):
        if self.is_huggingface_model_id(url):
            return url, False
        elif os.path.exists(url):
            return url, self.from_ckpt.value
        else:
            return download_ckpt(url), self.from_ckpt.value

    @staticmethod
    def highlight(str_to_highlight):
        return f"<b>{str_to_highlight}</b>"

    @staticmethod
    def is_huggingface_model_id(url):
        return len(url.split('/')) == 2

    def _ipython_display_(self):
        self.render()
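The class only reads data[key]["id"] plus an optional "trigger" entry, so a compatible model_index.json plausibly has this shape (entries are illustrative):

# Hypothetical shape of model_index.json, matching the lookups above:
index = {
    "Stable Diffusion 1.5": {"id": "runwayml/stable-diffusion-v1-5"},
    "My finetune": {"id": "someuser/some-model", "trigger": "trigger words"},
}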
--------------------------------------------------------------------------------
/ColabUI/ImageViewer.py:
--------------------------------------------------------------------------------
import os
from os.path import isfile, join
from IPython.display import display
from ipywidgets import Dropdown, Button, Image, VBox, HBox

class ImageViewer:
    def __init__(self, path: str = "outputs/favourite", width: int = 512, height: int = 512):
        self.path = path
        self.files = sorted([f for f in os.listdir(path) if isfile(join(path, f))])
        self.dropdown = Dropdown(
            options=[x for x in self.files],
            description="Image:",
        )
        self.preview = Image(
            width=width,
            height=height
        )
        if len(self.files) > 0: self.preview.value = open(join(self.path, self.files[0]), "rb").read()

        def dropdown_eventhandler(change):
            self.preview.value = open(join(self.path, change.new), "rb").read()
        self.dropdown.observe(dropdown_eventhandler, names='value')

        self.btn_refresh = Button(description="Refresh")
        def btn_handler(b):
            self.refresh()
        self.btn_refresh.on_click(btn_handler)

    def refresh(self):
        value = self.dropdown.value
        self.files = sorted([f for f in os.listdir(self.path) if isfile(join(self.path, f))])
        self.dropdown.options = [x for x in self.files]
        self.dropdown.value = value

    @property
    def render_element(self):
        return VBox([HBox([self.dropdown, self.btn_refresh]), self.preview])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()

--------------------------------------------------------------------------------
/ColabUI/Img2ImgRefinerUI.py:
--------------------------------------------------------------------------------
import io
import ipywidgets as widgets  # PIL.Image vs ipywidgets.Image
import numpy as np
import os
import torch
from IPython.display import display
from ipywidgets import FloatSlider, Dropdown, HBox, VBox, Text, Button, Label
from PIL import Image
from .BaseUI import BaseUI
from ..utils.image_utils import load_image_metadata

class Img2ImgRefinerUI:
    def __init__(self, ui: BaseUI):
        self.__generator = torch.Generator(device="cuda")
        self.__base_ui = ui
        self.__current_file = ""
        self.__current_img = None
        self.__preview_visible = True

        self.strength_field = FloatSlider(value=0.35, min=0, max=1, step=0.05, description="Strength: ")
        self.upscale_field = FloatSlider(value=1, min=1, max=4, step=0.25, description="Upscale factor: ")
        self.path_field = Text(description="Url:", placeholder="Image path...")
        self.load_button = Button(description="Choose")
        self.load_prompts_button = Button(description="Load prompts")
        self.image_preview = widgets.Image(width=300, height=400)
        self.hide_button = Button(description="Hide")
        self.size_label = Label(value="Size")

        def load_handler(b):
            if not os.path.isfile(self.path_field.value): return
            self.__current_file = self.path_field.value
            self.image_preview.value = open(self.__current_file, "rb").read()
            self.__current_img = Image.open(self.__current_file)
            self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
        self.load_button.on_click(load_handler)

        def load_prompts_handler(b):
            if not os.path.isfile(self.path_field.value): return
            prompt, negative = load_image_metadata(self.path_field.value)
            self.__base_ui.positive_prompt.value = prompt
            self.__base_ui.negative_prompt.value = negative
        self.load_prompts_button.on_click(load_prompts_handler)

        def hide_handler(b):
            self.__preview_visible = not self.__preview_visible
            self.image_preview.layout.width = "300px" if self.__preview_visible else "0px"
            self.hide_button.description = "Hide" if self.__preview_visible else "Show preview"
        self.hide_button.on_click(hide_handler)

        def on_slider_change(change):
            if self.__current_img is None: return
            self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
        self.upscale_field.observe(on_slider_change, 'value')

    def generate(self, pipe, init_image=None, generator=None):
        """Generate images given Img2Img Pipeline, and settings set in Base UI and Refiner UI."""
        if self.__base_ui.seed_field.value >= 0:
            seed = self.__base_ui.seed_field.value
        else:
            seed = self.__generator.seed()

        if init_image is None:
            init_image = self.__current_img

        init_image = init_image.convert('RGB')
        size = self.get_upscaled_size(init_image.size)
        init_image = init_image.resize(size, resample=Image.LANCZOS)

        g = torch.cuda.manual_seed(seed)
        self._metadata = self.__base_ui.get_metadata_string() + f"\nImg2Img Seed: {seed}, Noise Strength: {self.strength_field.value}, Upscale: {self.upscale_field.value} "

        results = pipe(image=init_image,
                       prompt=self.__base_ui.positive_prompt.value,
                       negative_prompt=self.__base_ui.negative_prompt.value,
                       num_inference_steps=self.__base_ui.steps_field.value,
                       num_images_per_prompt=self.__base_ui.batch_field.value,
                       guidance_scale=self.__base_ui.cfg_field.value,
                       guidance_rescale=self.__base_ui.cfg_rescale,
                       strength=self.strength_field.value,
                       width=size[0], height=size[1],
                       generator=g)
        return results

    @property
    def metadata(self):
        return self._metadata

    def get_upscaled_size(self, size):
        return (int(self.upscale_field.value * size[0]), int(self.upscale_field.value * size[1]))

    def __create_color_picker(self):
        picker = widgets.ColorPicker(description='Pick a color', value='#000000', concise=False)
        btn = Button(description="Apply")
        def load_color_handler(b):
            self.__current_img = self.__create_image_from_color(picker)
            self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
            f = io.BytesIO()
            self.__current_img.save(f, "png")
            self.image_preview.value = f.getvalue()
        btn.on_click(load_color_handler)

        foldout = widgets.Accordion(children=[HBox([picker, btn])])
        foldout.set_title(0, "From solid color")
        foldout.selected_index = None
        return foldout

    def __create_image_from_color(self, picker):
        h = picker.value.lstrip('#')
        rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
        width = int(self.__base_ui.width_field.value)
        height = int(self.__base_ui.height_field.value)
        x = np.tile(rgb, (height, width, 1)).astype(np.uint8)
        return Image.fromarray(x)

    @property
    def render_element(self):
        return VBox([self.strength_field, self.upscale_field,
                     HBox([self.path_field, self.load_button, self.load_prompts_button]),
                     self.__create_color_picker(),
                     self.size_label, self.image_preview, self.hide_button])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()
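__create_image_from_color parses the picker's hex string into an RGB triple by slicing byte pairs; a quick worked example:

# Worked example of the hex -> RGB slicing used above:
h = "#3a7bd5".lstrip('#')
rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
print(rgb)  # -> (58, 123, 213)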
--------------------------------------------------------------------------------
/ColabUI/InpaintRefinerUI.py:
--------------------------------------------------------------------------------
import io
import os
import torch
import numpy as np
from IPython.display import display
from ipywidgets import Text, FloatSlider, HBox, VBox, Button, Accordion
from PIL import Image
from .CanvasHolder import CanvasHolder
from ..utils.image_utils import load_image_metadata

#TODO can blur: https://huggingface.co/docs/diffusers/using-diffusers/inpaint

class InpaintRefinerUI:
    def __init__(self, ui):
        self.__generator = torch.Generator(device="cuda")
        self.__base_ui = ui
        self.__current_file = ""
        self.scale_factor = 0.5

        self.canvas_holder = CanvasHolder(200, 200)
        self.strength_field = FloatSlider(value=0.35, min=0, max=1, step=0.05, description="Strength: ")
        self.path_field = Text(description="Url:", placeholder="Image path...")
        self.load_button = Button(description="Choose")
        self.load_prompts_button = Button(description="Load prompts")

        def load_handler(b):
            if not os.path.isfile(self.path_field.value): return
            self.__current_file = self.path_field.value
            self.canvas_holder.add_background(self.path_field.value, scale_factor=self.scale_factor)
        self.load_button.on_click(load_handler)

        def load_prompts_handler(b):
            # Mirrors Img2ImgRefinerUI: fill the prompt fields from the image's saved metadata.
            if not os.path.isfile(self.path_field.value): return
            prompt, negative = load_image_metadata(self.path_field.value)
            self.__base_ui.positive_prompt.value = prompt
            self.__base_ui.negative_prompt.value = negative
        self.load_prompts_button.on_click(load_prompts_handler)

        #TODO accordion doesn't work for some reason
        self.accordion = Accordion(children=[self.canvas_holder.render_element])
        self.accordion.set_title(0, "Canvas")

    def generate(self, pipe, init_image=None, mask=None, generator=None):
        """Generate images given Inpaint Pipeline, and settings set in Base UI and Refiner UI."""
        if generator is None: generator = self.__generator
        if self.__base_ui.seed_field.value >= 0:
            seed = self.__base_ui.seed_field.value
        else:
            seed = generator.seed()

        if init_image is None:
            init_image = Image.open(self.__current_file)
        size = init_image.size
        # TODO check if mask size matches, or always scale to img?
        if mask is None:
            mask = Image.fromarray(self.canvas_holder.canvas.get_image_data()[:, :, -1])
            if self.scale_factor != 1:
                mask = mask.resize((int(mask.size[0] / self.scale_factor), int(mask.size[1] / self.scale_factor)))

        g = torch.cuda.manual_seed(seed)
        self._metadata = self.__base_ui.get_metadata_string() + f"\nImg2Img Seed: {seed}, Noise Strength: {self.strength_field.value} "

        results = pipe(image=init_image,
                       mask_image=mask,
                       prompt=self.__base_ui.positive_prompt.value,
                       negative_prompt=self.__base_ui.negative_prompt.value,
                       num_inference_steps=self.__base_ui.steps_field.value,
                       num_images_per_prompt=self.__base_ui.batch_field.value,
                       guidance_scale=self.__base_ui.cfg_field.value,
                       guidance_rescale=self.__base_ui.cfg_rescale,
                       strength=self.strength_field.value,
                       width=size[0], height=size[1],
                       generator=g)
        return results

    @property
    def metadata(self):
        return self._metadata

    @property
    def render_element(self):
        return VBox([self.strength_field,
                     HBox([self.path_field, self.load_button, self.load_prompts_button]),
                     self.canvas_holder.render_element])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()

--------------------------------------------------------------------------------
/ColabUI/KandinskyUI.py:
--------------------------------------------------------------------------------
import torch
from .BaseUI import BaseUI

class KandinskyUI(BaseUI):
    def generate(self, pipe, sampler="p_sampler"):
        if self.seed_field.value > 0:
            torch.manual_seed(self.seed_field.value)

        self._metadata = self.get_metadata_string() + f"Seed: {self.seed_field.value} "

        images = pipe.generate_text2img(self.positive_prompt.value,
                                        negative_prior_prompt=self.negative_prompt.value,
                                        num_steps=self.steps_field.value,
                                        batch_size=self.batch_field.value,
                                        guidance_scale=self.cfg_field.value,
                                        h=self.height_field.value, w=self.width_field.value,
                                        sampler=sampler)
        return images

--------------------------------------------------------------------------------
/ColabUI/LoraApplyer.py:
--------------------------------------------------------------------------------
from IPython.display import display
from ipywidgets import Dropdown, Button, HBox, VBox, Layout, Text, FloatSlider, Output
from ..utils.empty_output import EmptyOutput

class LoraApplyer:
    def __init__(self, colab, out: Output = None, cache = None):
        self.colab = colab
        if out is None: out = EmptyOutput()
        self.out = out
        if cache is None: cache = dict()  # guard: the dropdown iterates the cache below
        self.__cache = cache
        self.__applier_loras = dict()
        self.vbox = VBox()
        self.is_fused = False

        self.__setup_dropdown()

        self.add_button = Button(description="Add", layout=Layout(width='50px'))
        def add_lora(b):
            self.out.clear_output()
            with out: self.__add_lora(self.dropdown.value, 1)
        self.add_button.on_click(add_lora)

        self.fuse_button = Button(description="Fuse", layout=Layout(width='75px'), button_style='success')
        def on_fuse_btn(b):
            self.out.clear_output()
            with out: self.fuse_lora()
        self.fuse_button.on_click(on_fuse_btn)
        self.fuse_button.tooltip = "Merge loras into the model weights to speed-up inference"

        self.unfuse_button = Button(description="Unfuse", layout=Layout(width='75px'), button_style='danger')
        def on_unfuse_btn(b):
            self.out.clear_output()
            with out: self.unfuse_lora()
        self.unfuse_button.on_click(on_unfuse_btn)
        self.unfuse_button.tooltip = "Return model weights to the original state"

        self.__set_ui_fuse_state(False)

    def fuse_lora(self):
        if self.is_fused: return
        self.is_fused = True
        self.__set_ui_fuse_state(True)
        self.__apply_adapters()
        self.colab.pipe.fuse_lora()

    def unfuse_lora(self):
        if not self.is_fused: return
        self.is_fused = False
        self.__set_ui_fuse_state(False)
        self.colab.pipe.unfuse_lora()

    def __setup_dropdown(self):
        self.dropdown = Dropdown(
            options=[x for x in self.__cache],
            description="Lora:",
        )
        self.dropdown.description_tooltip = "Choose lora to load"

    def update_dropdown(self, new_adapter):
        self.dropdown.options = [x for x in self.__cache]

    def __add_lora(self, adapter: str, scale: float):
        if adapter in self.__applier_loras: return
        self.__applier_loras[adapter] = scale
        self.__rerender_items()
        self.__apply_adapters()

    def __rerender_items(self):
        w = []
        for adapter, scale in self.__applier_loras.items():
            slider = self.__add_lora_scale_slider(adapter, scale)
            remove_btn = self.__add_lora_remove_button(adapter)
            w.append(HBox([slider, remove_btn]))

        self.vbox.children = w

    def __add_lora_scale_slider(self, adapter, scale):
        slider = FloatSlider(min=0, max=1.5, value=scale, step=0.05, description=adapter)
        def value_changed(change):
            self.__applier_loras[adapter] = change.new
            self.__apply_adapters()
        slider.observe(value_changed, 'value')
        return slider

    def __add_lora_remove_button(self, adapter):
        remove_btn = Button(description="X", layout=Layout(width='40px'))
        def on_button(b):
            del self.__applier_loras[adapter]
            self.__rerender_items()
            self.__apply_adapters()
        remove_btn.on_click(on_button)
        return remove_btn

    def __apply_adapters(self):
        if self.is_fused: return
        self.colab.pipe.enable_lora()  #TODO use separate button
        self.colab.pipe.set_adapters([x for x in self.__applier_loras.keys()],
                                     adapter_weights=[x for x in self.__applier_loras.values()])

    def __set_ui_fuse_state(self, value):
        self.add_button.disabled = value
        self.fuse_button.disabled = value
        self.unfuse_button.disabled = not value
        self.fuse_button.layout.width = "50px" if value else "75px"
        self.unfuse_button.layout.width = "50px" if not value else "75px"
        #self.fuse_button.layout.visibility = "hidden" if value else "visible"
        #self.unfuse_button.layout.visibility = "visible" if value else "hidden"

    @property
    def render_element(self):
        return VBox([HBox([self.dropdown, self.add_button, self.fuse_button, self.unfuse_button]), self.vbox])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()
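Fuse/Unfuse map onto the underlying diffusers LoRA API; the same cycle by hand looks like this (adapter names are illustrative, pipe is any pipeline with loras loaded):

# The same fuse/unfuse cycle using the diffusers API directly:
pipe.set_adapters(["style_lora", "detail_lora"], adapter_weights=[0.8, 0.4])
pipe.fuse_lora()    # bake the weighted loras into the model weights (faster inference)
image = pipe("a portrait").images[0]
pipe.unfuse_lora()  # restore original weights so scales can be edited again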
--------------------------------------------------------------------------------
/ColabUI/LoraChoice.py:
--------------------------------------------------------------------------------
import markdown
from IPython.display import display
from ipywidgets import VBox, HTML, Output
from .LoraDownloader import LoraDownloader
from .LoraApplyer import LoraApplyer

class LoraChoice:
    def __init__(self, colab, out: Output = None, download_dir: str = "Lora", lora_cache: dict = None):
        self.colab = colab

        self.label_download = HTML(markdown.markdown("## Load Lora"))
        self.lora_downloader = LoraDownloader(colab, out, output_dir=download_dir, cache=lora_cache)
        self.label_apply = HTML(markdown.markdown("## Apply Lora"))
        self.lora_ui = LoraApplyer(colab, out, cache=self.lora_downloader.cache)

        self.lora_downloader.on_load_event.clear_callbacks()
        self.lora_downloader.on_load_event.add_callback(self.lora_ui.update_dropdown)

    @property
    def render_element(self):
        return VBox([self.label_download, self.lora_downloader.render_element, self.label_apply, self.lora_ui.render_element])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()

--------------------------------------------------------------------------------
/ColabUI/LoraDownloader.py:
--------------------------------------------------------------------------------
import os
import gdown
from IPython.display import display
from ipywidgets import Dropdown, Button, HBox, Layout, Text, Output
from ..utils.downloader import download_ckpt
from ..utils.event import Event
from ..utils.empty_output import EmptyOutput

class LoraDownloader:
    def __init__(self, colab, out: Output = None, output_dir: str = "Lora", cache = None):
        self.__cache = dict()  # adapter name -> filepath; instance-level so downloaders don't share state
        self.output_dir = output_dir
        self.colab = colab
        if out is None: out = EmptyOutput()
        self.out = out

        if cache is not None: self.load_cache(cache)

        self.url_text = Text(placeholder="Lora url:", layout=Layout(width="40%"))
        def url_changed(change):
            self.adapter_field.value = ""
        self.url_text.observe(url_changed, 'value')
        self.url_text.description_tooltip = "Url or filepath to Lora. Supports google drive links."

        self.adapter_field = Text(placeholder="Adapter name", layout=Layout(width='150px'))
        self.adapter_field.description_tooltip = "Custom name for Lora. Default is filename."

        self.load_button = Button(description="Load", layout=Layout(width='50px'))
        def load(b):
            self.out.clear_output()
            with self.out:
                self.load_lora(self.url_text.value, self.adapter_field.value)
        self.load_button.on_click(load)

        self.repair_button = Button(description="Repair", layout=Layout(width='75px'), button_style='info')
        def on_repair_btn(b):
            self.out.clear_output()
            with self.out: self.repair()
        self.repair_button.on_click(on_repair_btn)
        self.repair_button.tooltip = "Click to refresh list of loaded Loras in case of errors"

        self.on_load_event = Event()

    def load_lora(self, url, adapter_name=""):
        if os.path.isfile(url):
            filepath = url
        elif url.startswith("https://drive.google.com"):
            file_id = gdown.parse_url.parse_url(url)[0]
            if not self.output_dir.endswith("/"): self.output_dir += "/"
            filepath = gdown.download(f"https://drive.google.com/uc?id={file_id}", self.output_dir)
        elif url.startswith("https://civitai.com/api/download/models/") or url.endswith(".safetensors"):
            filepath = download_ckpt(url, self.output_dir)
        else:
            print(f"Error: unsupported lora source '{url}'")
            return

        filename = os.path.basename(filepath)
        if adapter_name == "":
            adapter_name = os.path.splitext(filename)[0]
            self.adapter_field.value = adapter_name
        # Load from the file's own directory (it may be a local path outside output_dir).
        self.colab.pipe.load_lora_weights(os.path.dirname(filepath), weight_name=filename, adapter_name=adapter_name)
        self.__cache[adapter_name] = filepath

        self.on_load_event.invoke(adapter_name)

    def load_cache(self, cache):
        for adapter_name, filepath in cache.items():
            try:
                filename = os.path.basename(filepath)
                self.colab.pipe.load_lora_weights(os.path.dirname(filepath), weight_name=filename, adapter_name=adapter_name)
                self.__cache[adapter_name] = filepath
            except: pass

    def repair(self):
        adapters_dict = self.colab.pipe.get_list_adapters()
        print(adapters_dict)
        adapters = {x for module in adapters_dict.values() for x in module}
        for a in adapters:
            self.__cache[a] = None
        self.on_load_event.invoke("")  #no need to call for each adapter, TODO argument

    @property
    def cache(self):
        return self.__cache

    @property
    def render_element(self):
        return HBox([self.url_text, self.adapter_field, self.load_button, self.repair_button])

    def render(self):
        display(self.render_element)

    def _ipython_display_(self):
        self.render()
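utils/event.py is not included in this section; a minimal Event matching how it is used here and in LoraChoice (add_callback, clear_callbacks, invoke with one argument) might look like:

# Hedged sketch of the Event interface used above (actual utils/event.py not shown):
class Event:
    def __init__(self):
        self._callbacks = []

    def add_callback(self, callback):
        self._callbacks.append(callback)

    def clear_callbacks(self):
        self._callbacks.clear()

    def invoke(self, *args, **kwargs):
        for callback in self._callbacks:
            callback(*args, **kwargs)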
sampler_name: 27 | case "Euler A": sampler = EulerAncestralDiscreteScheduler.from_config(config) 28 | case "DPM++": sampler = DPMSolverMultistepScheduler.from_config(config) 29 | case "DPM++ Karras": 30 | sampler = DPMSolverMultistepScheduler.from_config(config) 31 | sampler.use_karras_sigmas = True 32 | case "UniPC": 33 | sampler = UniPCMultistepScheduler.from_config(config) 34 | case _: raise NameError("Unknown sampler") 35 | pipe.scheduler = sampler 36 | print(f"Sampler '{sampler_name}' chosen") 37 | 38 | @property 39 | def render_element(self): 40 | return self.dropdown 41 | 42 | def render(self): 43 | display(self.render_element) 44 | 45 | def _ipython_display_(self): 46 | self.render() 47 | -------------------------------------------------------------------------------- /ColabUI/SettingsTabs.py: -------------------------------------------------------------------------------- 1 | from ipywidgets import Button, Tab, VBox, Output, Accordion, IntSlider, Text, Checkbox, FloatSlider 2 | from .SamplerChoice import SamplerChoice 3 | from .VaeChoice import VaeChoice 4 | from .TextualInversionChoice import TextualInversionChoice 5 | from .LoraChoice import LoraChoice 6 | import os 7 | 8 | class SettingsTabs: 9 | #TODO config file, default import file 10 | def __init__(self, colab): 11 | self.colab = colab 12 | self.output = Output(layout={'border': '1px solid black'}) 13 | self.clear_output_button = Button(description="Clear Log") 14 | def on_clear_clicked(b): 15 | self.output.clear_output() 16 | self.clear_output_button.on_click(on_clear_clicked) 17 | 18 | self.sampler_choice = SamplerChoice(self.colab, self.output) 19 | self.vae_choice = VaeChoice(self.colab, self.output) 20 | self.ti_choice = TextualInversionChoice(self.colab, self.output) 21 | self.lora_choice = LoraChoice(self.colab, self.output, lora_cache=colab.lora_cache) 22 | self.general_settings = self.__general_settings() 23 | 24 | self.tabs = Tab() 25 | self.tabs.children = [self.sampler_choice.render_element, 26 | self.vae_choice.render_element, 27 | self.ti_choice.render_element, 28 | self.lora_choice.render_element, 29 | self.general_settings] 30 | 31 | for i,x in enumerate(["Sampler", "VAE", "Textual Inversion", "Lora", "Other"]): 32 | self.tabs.set_title(i, x) 33 | 34 | self.accordion = Accordion(children=[VBox([self.tabs, self.clear_output_button, self.output])]) 35 | self.accordion.set_title(0, "Settings") 36 | self.accordion.selected_index = None 37 | 38 | def __general_settings(self): 39 | self.show_preview_checkbox = Checkbox(value=True, description="Show output previews") 40 | 41 | clip_slider = IntSlider(value=self.colab.ui.clip_skip, min=0, max=4, description="Clip Skip") 42 | def clip_value_changed(change): 43 | self.output.clear_output() 44 | with self.output: 45 | self.colab.ui.clip_skip = change.new 46 | print(f"Clip Skip set to {change.new}") 47 | clip_slider.observe(clip_value_changed, "value") 48 | 49 | cfg_rescale_slider = FloatSlider(value=0, min=0, max=1, step = 0.05, description="CFG Rescale") 50 | cfg_rescale_slider.description_tooltip = "Bigger value fixes contrast and overexposure" 51 | def rescale_value_changed(change): 52 | with self.output: self.colab.ui.cfg_rescale = change.new 53 | cfg_rescale_slider.observe(rescale_value_changed, "value") 54 | 55 | favourite_dir = Text(placeholder="Path to folder", description="Favourites:") 56 | favourite_dir.value = self.colab.favourite_dir 57 | def favourite_value_changed(change): 58 | self.output.clear_output() 59 | with self.output: 60 | if 
(os.path.isdir(change.new)): 61 | self.colab.favourite_dir = change.new 62 | favourite_dir.description = "Favourites:" 63 | print(f"Favourites foulder changed to {change.new}") 64 | else: 65 | favourite_dir.description = self.highlight("Favourites:", color="red") 66 | favourite_dir.observe(favourite_value_changed, "value") 67 | 68 | return VBox([self.show_preview_checkbox, favourite_dir, clip_slider, cfg_rescale_slider]) 69 | 70 | @staticmethod 71 | def highlight(str_to_highlight: str, color: str = "red") -> str: 72 | return f"{str_to_highlight}" 73 | 74 | @property 75 | def display_previews(self): 76 | return self.show_preview_checkbox.value 77 | 78 | @property 79 | def render_element(self): 80 | return self.accordion 81 | 82 | def render(self): 83 | display(self.render_element) 84 | 85 | def _ipython_display_(self): 86 | self.render() 87 | -------------------------------------------------------------------------------- /ColabUI/TextualInversionChoice.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ipywidgets import Text, Button, Layout, Combobox, VBox, Output 3 | from ..utils.empty_output import EmptyOutput 4 | from ..utils.markdown import SpoilerLabel 5 | 6 | class TextualInversionChoice: 7 | def __init__(self, colab, out:Output = None, default_path:str = "/content/embeddings/"): 8 | self.colab = colab 9 | if out is None: out = EmptyOutput() 10 | self.out = out 11 | self.tokens = list() 12 | 13 | self.tooltip_label = SpoilerLabel("Tooltip", "Paste a path to Textual Inversion .pt file, or path to a folder containig them.") 14 | self.path = Text(description="TI:", placeholder='Path to file or folder...', layout=Layout(width="50%")) 15 | self.path.description_tooltip = "Path to a Textual Inversion file or root folder" 16 | if default_path is not None: self.path.value = default_path 17 | 18 | self.load_button = Button(description="Load", button_style='success') 19 | def on_load_clicked(b): 20 | self.out.clear_output() 21 | with self.out: 22 | self.load(self.colab.pipe) 23 | self.load_button.on_click(on_load_clicked) 24 | 25 | self.token_list = Combobox(placeholder="Start typing token", description="Tokens:", ensure_option=True) 26 | 27 | def load(self, pipe): 28 | if self.path.value == "": 29 | raise ValueError("Text field shouldn't be empty") 30 | 31 | if os.path.isfile(self.path.value): 32 | dir, filename = os.path.split(self.path.value) 33 | self.__load_textual_inversion(self.colab.pipe, dir, filename) 34 | elif os.path.isdir(self.path.value): 35 | self.__load_textual_inversions_from_folder(self.colab.pipe, self.path.value) 36 | 37 | #TODO load from file: https://huggingface.co/docs/diffusers/api/loaders/textual_inversion 38 | def __load_textual_inversion(self, pipe, path: str, filename: str): 39 | self.colab.pipe.load_textual_inversion(path, weight_name=filename) 40 | 41 | token = f"<{os.path.splitext(filename)[0]}>" 42 | self.__add_token(token) 43 | print(token) 44 | 45 | def __load_textual_inversions_from_folder(self, pipe, root_folder: str): 46 | n = 0 47 | for path, subdirs, files in os.walk(root_folder): 48 | for name in files: 49 | try: 50 | if os.path.splitext(name)[1] != ".pt": continue 51 | self.colab.pipe.load_textual_inversion(path, weight_name=name) 52 | 53 | token = f"<{os.path.splitext(name)[0]}>" 54 | self.__add_token(token) 55 | print(f"{path}:\t{token}") 56 | n += 1 57 | except: pass #TODO save exception ? 
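        # n counts only the embeddings that actually loaded; unreadable files are skipped silently above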
58 | print(f"{n} items added") 59 | 60 | def __add_token(self, token:str): 61 | self.tokens.append(token) 62 | self.token_list.options = self.tokens 63 | 64 | @property 65 | def render_element(self): 66 | return VBox([self.tooltip_label, self.path, self.load_button, self.token_list]) 67 | 68 | def render(self): 69 | display(self.render_element) 70 | 71 | def _ipython_display_(self): 72 | self.render() 73 | -------------------------------------------------------------------------------- /ColabUI/VaeChoice.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from ipywidgets import Text, Layout, Button, HBox, VBox, Output 4 | from diffusers import AutoencoderKL 5 | from ..utils.downloader import download_ckpt 6 | from ..utils.empty_output import EmptyOutput 7 | from ..utils.markdown import SpoilerLabel 8 | 9 | class VaeChoice: 10 | def __init__(self, colab, out:Output = None, default_id:str = "waifu-diffusion/wd-1-5-beta2"): 11 | self.colab = colab 12 | if out is None: out = EmptyOutput() 13 | self.out = out 14 | 15 | self.tooltip_label = SpoilerLabel("Loading single file", "If vae is in single file, just leave 'Subfolder' field empty.") 16 | self.id_text = Text(description="VAE:", placeholder='Id or path...', layout=Layout(width="35%")) 17 | self.id_text.description_tooltip = "Huggingface model id, url, or a path to a root folder/file" 18 | if default_id is not None: self.id_text.value = default_id 19 | self.subfolder_text = Text(description="Subfolder:", layout=Layout(width="25%")) 20 | self.subfolder_text.description_tooltip = "Subfolder name, if using diffusers-style model" 21 | if default_id is not None: self.subfolder_text.value = "vae" 22 | 23 | self.button = Button(description="Load", button_style='success') 24 | 25 | def on_button_clicked(b): 26 | self.out.clear_output() 27 | with self.out: 28 | self.load_vae(self.colab.pipe) 29 | 30 | self.button.on_click(on_button_clicked) 31 | 32 | def load_vae(self, pipe): 33 | if self.id_text.value == "": 34 | raise ValueError("VAE text field shouldn't be empty") 35 | 36 | url = self.id_text.value 37 | if url.endswith(".safetensors") or url.startswith("https://civitai.com/api/download/models/"): 38 | if not os.path.isfile(url): url = download_ckpt(url) #TODO download dir 39 | vae = AutoencoderKL.from_single_file(url, torch_dtype=torch.float16) 40 | else: 41 | if self.subfolder_text.value == "": 42 | raise ValueError("Subfolder text field shouldn't be empty when loading diffusers-style model") 43 | vae = AutoencoderKL.from_pretrained(url, subfolder=self.subfolder_text.value, torch_dtype=torch.float16) 44 | 45 | pipe.vae = vae.to("cuda") 46 | print(f"{self.id_text.value} VAE loaded") 47 | 48 | @property 49 | def render_element(self): 50 | return VBox([self.tooltip_label, HBox([self.id_text, self.subfolder_text]), self.button]) 51 | 52 | def render(self): 53 | display(self.render_element) 54 | 55 | def _ipython_display_(self): 56 | self.render() 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Andrey Sugonyaev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Colab UI for Stable Diffusion 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_diffusers_colab_ui.ipynb) 4 | 5 | Colab UI for running Stable Diffusion with 🤗 [diffusers](https://github.com/huggingface/diffusers) library. 6 | 7 | This repository aims to create a GUI using native Colab and IPython widgets. 8 | This eliminates the need for gradio WebUI, which [seems to be prohibited](https://github.com/googlecolab/colabtools/issues/3591) now on Google Colab. 9 | 10 | ![UI example](docs/ui-example.jpg) 11 | 12 | ### Features: 13 | - [X] UI based on IPython widgets 14 | - [X] Stable diffusion 1.x, 2.x, XL 15 | - [X] Load models in from Huggingface, and models in Automatic1111 (*ckpt/safetensors*) format 16 | - [X] Change VAE and sampler 17 | - [X] Load textual inversions 18 | - [x] Load LoRAs 19 | - [x] Img2Img 20 | - [x] Inpainting 21 | 22 | ### SDXL 23 | SDXL is supported to an extent. You can load huggingface models, but loading models in A1111 format may be impossible because of RAM limitation. 24 | Also it looks like textual inversions and loading LoRAs is not supported for now. 25 | 26 | ## Kandinsky 27 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_kandinsky_colab_ui.ipynb) 28 | 29 | Colab UI for Kandinsky txt2img and image mixing. 
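For reference, a minimal sketch of how the Kandinsky notebook wires this up (mirroring its setup and generation cells; run in Colab after the install cell):

```python
from kandinsky2 import get_kandinsky2
from StableDiffusionUi.ColabUI.KandinskyUI import KandinskyUI

# create the txt2img model, as in the notebook's "Create model" cell
model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2',
                       model_version='2.1', use_flash_attention=False)

ui = KandinskyUI()
ui.render()                               # display the prompt/size/steps widgets
images = ui.generate(model, "p_sampler")  # generate with the current widget values
```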
30 | 31 | Original repo: https://github.com/ai-forever/Kandinsky-2 32 | -------------------------------------------------------------------------------- /artists.txt: -------------------------------------------------------------------------------- 1 | Alphonse Mucha 2 | Coles Phillips 3 | Frank Frazetta 4 | Alex Russell Flint 5 | Jeffrey Catherine Jones 6 | Ilya Kuvshinov 7 | Josephine Wall 8 | anne bachelier 9 | Russ Mills 10 | Ivan Bilibin 11 | Helen Frankenthaler 12 | Sergio Aragones 13 | Don Bluth 14 | Carl Barks 15 | Gemma Correll 16 | Tex Avery 17 | Matt Bors 18 | Quentin Blake 19 | Richard Scarry 20 | Alex Toth 21 | Steve Dillon 22 | Leone Frollo 23 | Elsa Beskow 24 | Howard Chaykin 25 | Rafael Albuquerque 26 | Hirohiko Araki 27 | Hiromu Arakawa 28 | Arthur Rackham 29 | Henri Fantin-Latour 30 | Eddie Campbell 31 | Abigail Larson 32 | Kaethe Butcher 33 | Aubrey Beardsley 34 | Kate Greenaway 35 | John Blanche 36 | Simon Bisley 37 | -------------------------------------------------------------------------------- /docs/ui-example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oneir0mancer/stable-diffusion-diffusers-colab-ui/fd69311417ddbf3f051d893af445d10e14b6542b/docs/ui-example.jpg -------------------------------------------------------------------------------- /model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "Stable Diffusion v1.5" : { "id" : "runwayml/stable-diffusion-v1-5" }, 3 | "Stable Diffusion XL" : { "id" : "stabilityai/stable-diffusion-xl-base-1.0" }, 4 | "Animagine XL" : { "id" : "cagliostrolab/animagine-xl-3.0" }, 5 | "Pony Diffusion XL" : { "id" : "stablediffusionapi/pony-diffusion-v6-xl" }, 6 | "Counterfeit-V2.5" : { "id" : "gsdf/Counterfeit-V2.5" }, 7 | "Analog Diffusion" : { "id" : "wavymulder/Analog-Diffusion", "trigger": "analog style" }, 8 | "Dreamlike Diffusion" : { "id" : "dreamlike-art/dreamlike-diffusion-1.0", "trigger": "dreamlikeart" }, 9 | "Dreamlike Photoreal 2.0" : { "id" : "dreamlike-art/dreamlike-photoreal-2.0" }, 10 | "Openjourney v4" : { "id" : "prompthero/openjourney-v4" }, 11 | "Pastel Mix" : { "id" : "andite/pastel-mix" }, 12 | "Pony Diffusion v4" : { "id" : "AstraliteHeart/pony-diffusion-v4" }, 13 | "ACertainThing" : { "id" : "JosephusCheung/ACertainThing" }, 14 | "BPModel" : { "id" : "Crosstyan/BPModel" }, 15 | "SomethingV2.2" : { "id" : "NoCrypt/SomethingV2_2" }, 16 | "Anything v3.0" : { "id" : "Linaqruf/anything-v3.0" }, 17 | "Anything v4.0 mix" : { "id" : "andite/anything-v4.0" }, 18 | "Waifu Diffusion v1.4 " : { "id" : "hakurei/waifu-diffusion" }, 19 | "Waifu Diffusion v1.5 " : { "id" : "waifu-diffusion/wd-1-5-beta3" } 20 | } 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchvision==0.20.1 3 | xformers 4 | accelerate 5 | diffusers 6 | omegaconf 7 | peft 8 | pytorch-lightning 9 | safetensors 10 | transformers 11 | -------------------------------------------------------------------------------- /scripts/convert_original_stable_diffusion_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Conversion script for the LDM checkpoints. """ 16 | 17 | import argparse 18 | 19 | import torch 20 | 21 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument( 28 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 29 | ) 30 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 31 | parser.add_argument( 32 | "--original_config_file", 33 | default=None, 34 | type=str, 35 | help="The YAML config file corresponding to the original architecture.", 36 | ) 37 | parser.add_argument( 38 | "--num_in_channels", 39 | default=None, 40 | type=int, 41 | help="The number of input channels. If `None` number of input channels will be automatically inferred.", 42 | ) 43 | parser.add_argument( 44 | "--scheduler_type", 45 | default="pndm", 46 | type=str, 47 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", 48 | ) 49 | parser.add_argument( 50 | "--pipeline_type", 51 | default=None, 52 | type=str, 53 | help=( 54 | "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" 55 | ". If `None` pipeline will be automatically inferred." 56 | ), 57 | ) 58 | parser.add_argument( 59 | "--image_size", 60 | default=None, 61 | type=int, 62 | help=( 63 | "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" 64 | " Base. Use 768 for Stable Diffusion v2." 65 | ), 66 | ) 67 | parser.add_argument( 68 | "--prediction_type", 69 | default=None, 70 | type=str, 71 | help=( 72 | "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" 73 | " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." 74 | ), 75 | ) 76 | parser.add_argument( 77 | "--extract_ema", 78 | action="store_true", 79 | help=( 80 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" 81 | " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" 82 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." 83 | ), 84 | ) 85 | parser.add_argument( 86 | "--upcast_attention", 87 | action="store_true", 88 | help=( 89 | "Whether the attention computation should always be upcasted. This is necessary when running stable" 90 | " diffusion 2.1." 
91 | ), 92 | ) 93 | parser.add_argument( 94 | "--from_safetensors", 95 | action="store_true", 96 | help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", 97 | ) 98 | parser.add_argument( 99 | "--to_safetensors", 100 | action="store_true", 101 | help="Whether to store pipeline in safetensors format or not.", 102 | ) 103 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 104 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 105 | parser.add_argument( 106 | "--stable_unclip", 107 | type=str, 108 | default=None, 109 | required=False, 110 | help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", 111 | ) 112 | parser.add_argument( 113 | "--stable_unclip_prior", 114 | type=str, 115 | default=None, 116 | required=False, 117 | help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", 118 | ) 119 | parser.add_argument( 120 | "--clip_stats_path", 121 | type=str, 122 | help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", 123 | required=False, 124 | ) 125 | parser.add_argument( 126 | "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." 127 | ) 128 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 129 | args = parser.parse_args() 130 | 131 | pipe = download_from_original_stable_diffusion_ckpt( 132 | checkpoint_path=args.checkpoint_path, 133 | original_config_file=args.original_config_file, 134 | image_size=args.image_size, 135 | prediction_type=args.prediction_type, 136 | model_type=args.pipeline_type, 137 | extract_ema=args.extract_ema, 138 | scheduler_type=args.scheduler_type, 139 | num_in_channels=args.num_in_channels, 140 | upcast_attention=args.upcast_attention, 141 | from_safetensors=args.from_safetensors, 142 | device=args.device, 143 | stable_unclip=args.stable_unclip, 144 | stable_unclip_prior=args.stable_unclip_prior, 145 | clip_stats_path=args.clip_stats_path, 146 | controlnet=args.controlnet, 147 | ) 148 | 149 | if args.half: 150 | pipe.to(torch_dtype=torch.float16) 151 | 152 | if args.controlnet: 153 | # only save the controlnet model 154 | pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 155 | else: 156 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) -------------------------------------------------------------------------------- /scripts/convert_vae_pt_to_diffusers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | 4 | import requests 5 | import torch 6 | from omegaconf import OmegaConf 7 | 8 | from diffusers import AutoencoderKL 9 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( 10 | assign_to_checkpoint, 11 | conv_attn_to_linear, 12 | create_vae_diffusers_config, 13 | renew_vae_attention_paths, 14 | renew_vae_resnet_paths, 15 | ) 16 | 17 | 18 | def custom_convert_ldm_vae_checkpoint(checkpoint, config): 19 | vae_state_dict = checkpoint 20 | 21 | new_checkpoint = {} 22 | 23 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 24 | 
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 25 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 26 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 27 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 28 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 29 | 30 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 31 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 32 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 33 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 34 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 35 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 36 | 37 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 38 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 39 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 40 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 41 | 42 | # Retrieves the keys for the encoder down blocks only 43 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 44 | down_blocks = { 45 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 46 | } 47 | 48 | # Retrieves the keys for the decoder up blocks only 49 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 50 | up_blocks = { 51 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 52 | } 53 | 54 | for i in range(num_down_blocks): 55 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 56 | 57 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 58 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 59 | f"encoder.down.{i}.downsample.conv.weight" 60 | ) 61 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 62 | f"encoder.down.{i}.downsample.conv.bias" 63 | ) 64 | 65 | paths = renew_vae_resnet_paths(resnets) 66 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 67 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 68 | 69 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 70 | num_mid_res_blocks = 2 71 | for i in range(1, num_mid_res_blocks + 1): 72 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 73 | 74 | paths = renew_vae_resnet_paths(resnets) 75 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 76 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 77 | 78 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 79 | paths = renew_vae_attention_paths(mid_attentions) 80 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 81 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 82 | 
conv_attn_to_linear(new_checkpoint) 83 | 84 | for i in range(num_up_blocks): 85 | block_id = num_up_blocks - 1 - i 86 | resnets = [ 87 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 88 | ] 89 | 90 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 91 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 92 | f"decoder.up.{block_id}.upsample.conv.weight" 93 | ] 94 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 95 | f"decoder.up.{block_id}.upsample.conv.bias" 96 | ] 97 | 98 | paths = renew_vae_resnet_paths(resnets) 99 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 100 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 101 | 102 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 103 | num_mid_res_blocks = 2 104 | for i in range(1, num_mid_res_blocks + 1): 105 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 106 | 107 | paths = renew_vae_resnet_paths(resnets) 108 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 109 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 110 | 111 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 112 | paths = renew_vae_attention_paths(mid_attentions) 113 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 114 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 115 | conv_attn_to_linear(new_checkpoint) 116 | return new_checkpoint 117 | 118 | 119 | def vae_pt_to_vae_diffuser( 120 | checkpoint_path: str, 121 | output_path: str, 122 | ): 123 | # Only support V1 124 | r = requests.get( 125 | " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" 126 | ) 127 | io_obj = io.BytesIO(r.content) 128 | 129 | original_config = OmegaConf.load(io_obj) 130 | image_size = 512 131 | device = "cuda" if torch.cuda.is_available() else "cpu" 132 | checkpoint = torch.load(checkpoint_path, map_location=device) 133 | 134 | # Convert the VAE model. 
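    # create_vae_diffusers_config reads the LDM yaml and builds the AutoencoderKL kwargs;
    # custom_convert_ldm_vae_checkpoint (above) renames the checkpoint keys to match.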
135 | vae_config = create_vae_diffusers_config(original_config, image_size=image_size) 136 | converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config) 137 | 138 | vae = AutoencoderKL(**vae_config) 139 | vae.load_state_dict(converted_vae_checkpoint) 140 | vae.save_pretrained(output_path) 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | 146 | parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") 147 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") 148 | 149 | args = parser.parse_args() 150 | 151 | vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path) -------------------------------------------------------------------------------- /sd_diffusers_colab_ui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "cellView": "form", 18 | "id": "x-z4MGy9rz9I" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "#@title #Install dependencies\n", 23 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi\n", 24 | "\n", 25 | "!pip install -r StableDiffusionUi/requirements.txt\n", 26 | "!wget https://github.com/huggingface/diffusers/raw/refs/heads/main/examples/community/lpw_stable_diffusion_xl.py -O StableDiffusionUi/lpw_stable_diffusion_xl.py\n", 27 | "!wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", 28 | "\n", 29 | "!mkdir -p outputs/{txt2img,img2img}\n", 30 | "!git clone https://huggingface.co/embed/negative /content/embeddings/negative\n", 31 | "!git clone https://huggingface.co/embed/lora /content/Lora/positive\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": { 37 | "id": "QRvNp3ElcaA3" 38 | }, 39 | "source": [ 40 | "# Creating model pipeline" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "dVlvLpShS9eo" 47 | }, 48 | "source": [ 49 | "\n", 50 | "
\n", 51 | "Description
TLDR: use dropdown or just paste stuff
\n", 52 | "\n", 53 | "---\n", 54 | "\n", 55 | "You can just choose a model from a dropdown of popular models, which will also give you a link to a model page on [huggingface](https://huggingface.co) and any trigger words you need to use in the prompt.\n", 56 | "\n", 57 | "Alternatively, you can paste a **model_id**, **url**, or **local path** for any model from huggingface:\n", 58 | "\n", 59 | "* **model_id** looks like this: `gsdf/Counterfeit-V2.5`.\n", 60 | "
You can use it for any *proper* diffusers model in huggingface, i.e. ones that have a lot of folders in their **Files**, and a file called **model_index.json**\n", 61 | "
`🔲 A1111 format` toggle should be **off**.\n", 62 | "\n", 63 | "* **url** from HuggingFace may look like this:\n", 64 | "```\n", 65 | "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors\n", 66 | "```\n", 67 | "Use it when the model is in Automatic1111 format (as in, just a single file).\n", 68 | "
`✅ A1111 format` toggle should be **on**.\n", 69 | "\n", 70 | "* **url** from CivitAI should look like this:\n", 71 | "```\n", 72 | "https://civitai.com/api/download/models/XXXXXX\n", 73 | "```\n", 74 | "(the link you get from \"Download\" button)\n", 75 | "
`✅ A1111 format` toggle should be **on**.\n", 76 | "\n", 77 | "* **local path** is just a path to a model folder (if diffusers format) or model file (if A1111 format).\n", 78 | "
Set toggle accordingly.\n", 79 | "\n", 80 | "
" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "dviyU3E6jK01" 87 | }, 88 | "source": [ 89 | "
\n", 90 | "Where to find urls\n", 91 | "\n", 92 | "A good place to find model urls is [camenduru colabs repo](https://https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n", 93 | "\n", 94 | "
" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "cellView": "form", 102 | "id": "k_TR5dUwznN9" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "#@title Render model choice UI\n", 107 | "\n", 108 | "from StableDiffusionUi.ColabUI.ColabWrapper import ColabWrapper\n", 109 | "\n", 110 | "colab = ColabWrapper(\"outputs/txt2img\")\n", 111 | "colab.render_model_index(\"/content/StableDiffusionUi/model_index.json\")" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "cellView": "form", 119 | "id": "IBdm3HvI02Yo" 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "#@title Load chosen SD model\n", 124 | "\n", 125 | "#TODO we have to use this until LPW properly supports clip skip\n", 126 | "from StableDiffusionUi.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline\n", 127 | "colab.load_model(StableDiffusionLongPromptWeightingPipeline, custom_pipeline=None)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "source": [ 133 | "#@title (*Optional*) Run this instead to load SDXL model\n", 134 | "from diffusers import StableDiffusionXLPipeline\n", 135 | "from StableDiffusionUi.lpw_stable_diffusion_xl import SDXLLongPromptWeightingPipeline\n", 136 | "\n", 137 | "colab.load_model(SDXLLongPromptWeightingPipeline, custom_pipeline=\"lpw_stable_diffusion_xl\")" 138 | ], 139 | "metadata": { 140 | "cellView": "form", 141 | "id": "AG4ChqboGDO_" 142 | }, 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": { 149 | "id": "pAJzkJWpeOxk" 150 | }, 151 | "source": [ 152 | "
\n", 153 | "Prompting detail\n", 154 | "\n", 155 | "I use [this](https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion) custom pipeline which doesn't have a token length limit and allows to use weights in prompt.\n", 156 | "\n", 157 | "
" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": { 163 | "id": "hLmzZmRIZMpC" 164 | }, 165 | "source": [ 166 | "# Generating images" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "cellView": "form", 174 | "id": "oFruEVcl4ODt" 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "#@title Render UI\n", 179 | "#@markdown ---\n", 180 | "from StableDiffusionUi.ColabUI.DiffusionPipelineUI import DiffusionPipelineUI\n", 181 | "\n", 182 | "ui = DiffusionPipelineUI()\n", 183 | "colab.render_generation_ui(ui)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "id": "Y43YLjY7F7NI", 191 | "cellView": "form" 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "#@title Run this cell to generate images\n", 196 | "colab.generate()" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": { 202 | "id": "gthYZKknKsw7" 203 | }, 204 | "source": [ 205 | "#Img2Img" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "source": [ 211 | "#@title Img2img refiner UI\n", 212 | "# We have to load custom pipeline ourselves because you can't pass it when reusing pipe components\n", 213 | "from diffusers.utils import get_class_from_dynamic_module\n", 214 | "from diffusers import AutoPipelineForImage2Image\n", 215 | "\n", 216 | "if colab.custom_pipeline is not None:\n", 217 | " Interface = get_class_from_dynamic_module(colab.custom_pipeline, module_file=f\"{colab.custom_pipeline}.py\")\n", 218 | "else:\n", 219 | " Interface = StableDiffusionLongPromptWeightingPipeline #AutoPipelineForImage2Image\n", 220 | "\n", 221 | "colab.render_refiner_ui(Interface)" 222 | ], 223 | "metadata": { 224 | "cellView": "form", 225 | "id": "d9oniqPLdZ4x" 226 | }, 227 | "execution_count": null, 228 | "outputs": [] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "source": [ 233 | "#@title Run img2img\n", 234 | "colab.refiner_generate(output_dir=\"outputs/img2img\")" 235 | ], 236 | "metadata": { 237 | "cellView": "form", 238 | "id": "zyk-JX0OeKvB" 239 | }, 240 | "execution_count": null, 241 | "outputs": [] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "source": [ 246 | "#Inpainting" 247 | ], 248 | "metadata": { 249 | "id": "o7aDqCk_EkwD" 250 | } 251 | }, 252 | { 253 | "cell_type": "code", 254 | "source": [ 255 | "#@title Setup\n", 256 | "#@markdown If UI doesn't show you may need to restart runtime. 
This stuff is experimental and may not work properly.\n",
257 | "from google.colab import output\n",
258 | "output.enable_custom_widget_manager()\n",
259 | "\n",
260 | "!pip install ipycanvas"
261 | ],
262 | "metadata": {
263 | "cellView": "form",
264 | "id": "q0uVdOA7EnRj"
265 | },
266 | "execution_count": null,
267 | "outputs": []
268 | },
269 | {
270 | "cell_type": "code",
271 | "source": [
272 | "#@title Inpainting UI\n",
273 | "#@markdown Click the \"Save\" button before running inpaint.\n",
274 | "\n",
275 | "# We have to load the custom pipeline ourselves because you can't pass it when reusing pipe components\n",
276 | "from diffusers.utils import get_class_from_dynamic_module\n",
277 | "from diffusers import AutoPipelineForInpainting\n",
278 | "from StableDiffusionUi.ColabUI.InpaintRefinerUI import InpaintRefinerUI\n",
279 | "\n",
280 | "if colab.custom_pipeline is not None:\n",
281 | "    Interface = get_class_from_dynamic_module(colab.custom_pipeline, module_file=f\"{colab.custom_pipeline}.py\")\n",
282 | "else:\n",
283 | "    Interface = StableDiffusionLongPromptWeightingPipeline  #TODO actually need to check whether we use XL or not\n",
284 | "\n",
285 | "inpaint_ui = InpaintRefinerUI(colab.ui)\n",
286 | "colab.render_inpaint_ui(Interface, inpaint_ui)"
287 | ],
288 | "metadata": {
289 | "cellView": "form",
290 | "id": "SGb1bthqEwvi"
291 | },
292 | "execution_count": null,
293 | "outputs": []
294 | },
295 | {
296 | "cell_type": "code",
297 | "source": [
298 | "#@title Run inpaint\n",
299 | "colab.inpaint_generate(output_dir=\"outputs/img2img\")"
300 | ],
301 | "metadata": {
302 | "cellView": "form",
303 | "id": "zvi2DqRhFZry"
304 | },
305 | "execution_count": null,
306 | "outputs": []
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "source": [
311 | "#Conversion"
312 | ],
313 | "metadata": {
314 | "id": "mHUsrD28erB4"
315 | }
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {
320 | "id": "kRUTJXoBc4tG"
321 | },
322 | "source": [
323 | "## Converting models from ckpt\n",
324 | "You shouldn't really need this, because diffusers now supports loading from A1111 format.\n",
325 | "\n",
326 | "But if you want, you can download a model and convert it to diffusers format."
327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": { 332 | "id": "av9urfOqi7Ak" 333 | }, 334 | "source": [ 335 | "NOTE: A good place to find model urls is [camenduru colabs repo](https://https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": { 342 | "cellView": "form", 343 | "id": "w1JZPBKId6Zv" 344 | }, 345 | "outputs": [], 346 | "source": [ 347 | "#@markdown ##Download ckpt file\n", 348 | "#@markdown Download ckpt/safetensors file from huggingface url.\n", 349 | "\n", 350 | "#@markdown ---\n", 351 | "import os\n", 352 | "\n", 353 | "url = \"https://huggingface.co/mekabu/MagicalMix_v2/resolve/main/MagicalMix_v2.safetensors\" #@param {type:\"string\"}\n", 354 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n", 355 | "model_name = url.split('/')[-1]\n", 356 | "\n", 357 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n", 358 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n", 359 | "os.system(bashCommand)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "cellView": "form", 367 | "id": "dOn6tR_4eJ8k" 368 | }, 369 | "outputs": [], 370 | "source": [ 371 | "#@markdown ##(Optional) Or just upload file\n", 372 | "from google.colab import files\n", 373 | "f = files.upload()\n", 374 | "\n", 375 | "ckpt_dump_folder = \"/content/\"\n", 376 | "for key in f.keys():\n", 377 | " model_name = key\n", 378 | " break" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "cellView": "form", 386 | "id": "N1GOOJrej6tW" 387 | }, 388 | "outputs": [], 389 | "source": [ 390 | "#@markdown ##Run conversion script\n", 391 | "#@markdown Paste a path to where you want this script to cache a model into `dump_path`.\n", 392 | "\n", 393 | "#@markdown Keep `override_path` empty unless you uploaded your file to some custom directory.\n", 394 | "\n", 395 | "import os\n", 396 | "import torch\n", 397 | "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n", 398 | "\n", 399 | "from_safetensors = True #@param {type:\"boolean\"}\n", 400 | "override_path = \"\" #@param {type:\"string\"}\n", 401 | "if override_path != \"\":\n", 402 | " checkpoint_path = override_path\n", 403 | "else:\n", 404 | " checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n", 405 | "\n", 406 | "pipe = download_from_original_stable_diffusion_ckpt(\n", 407 | " checkpoint_path=checkpoint_path,\n", 408 | " original_config_file = \"/content/v1-inference.yaml\",\n", 409 | " from_safetensors=from_safetensors,\n", 410 | " )\n", 411 | "\n", 412 | "dump_path=\"models/ModelName/\" #@param {type:\"string\"}\n", 413 | "pipe.save_pretrained(dump_path, safe_serialization=from_safetensors)\n", 414 | "\n", 415 | "pipe = pipe.to(\"cuda\")\n", 416 | "pipe.safety_checker = None\n", 417 | "pipe.to(torch_dtype=torch.float16)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "markdown", 422 | "metadata": { 423 | "id": "ii9UDlWQnIhg" 424 | }, 425 | "source": [ 426 | "## Converting VAE" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "metadata": { 433 | "cellView": "form", 434 | "id": "916cyCYPnNVt" 435 | }, 436 | "outputs": [], 437 | "source": [ 438 | "#@markdown ##Download vae file\n", 439 | "#@markdown Basically, the same thing as 
model files\n", 440 | "\n", 441 | "#@markdown ---\n", 442 | "import os\n", 443 | "\n", 444 | "url = \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\" #@param {type:\"string\"}\n", 445 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n", 446 | "model_name = url.split('/')[-1]\n", 447 | "\n", 448 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n", 449 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n", 450 | "os.system(bashCommand)" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": { 457 | "cellView": "form", 458 | "id": "dWC8UpJ3m8Fa" 459 | }, 460 | "outputs": [], 461 | "source": [ 462 | "#@markdown Paste a path to where you want this script to cache a vae into `dump_path`.\n", 463 | "from_safetensors = False #@param {type:\"boolean\"}\n", 464 | "dump_path=\"vae/VaeName/\" #@param {type:\"string\"}\n", 465 | "\n", 466 | "checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n", 467 | "bashCommand = f\"python /content/StableDiffusionUi/scripts/convert_vae_pt_to_diffusers.py --vae_pt_path {checkpoint_path} --dump_path {dump_path}\"\n", 468 | "if from_safetensors:\n", 469 | " bashCommand += \" --from_safetensors\"\n", 470 | "os.system(bashCommand)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "source": [ 476 | "#Flush model\n" 477 | ], 478 | "metadata": { 479 | "id": "33zsLV1d0CKE" 480 | } 481 | }, 482 | { 483 | "cell_type": "code", 484 | "source": [ 485 | "import gc\n", 486 | "import torch\n", 487 | "\n", 488 | "del colab.pipe\n", 489 | "torch.cuda.empty_cache()\n", 490 | "gc.collect()" 491 | ], 492 | "metadata": { 493 | "id": "aWqoBQP60OyY" 494 | }, 495 | "execution_count": null, 496 | "outputs": [] 497 | } 498 | ], 499 | "metadata": { 500 | "accelerator": "GPU", 501 | "colab": { 502 | "collapsed_sections": [ 503 | "gthYZKknKsw7", 504 | "o7aDqCk_EkwD", 505 | "mHUsrD28erB4", 506 | "kRUTJXoBc4tG", 507 | "ii9UDlWQnIhg", 508 | "33zsLV1d0CKE" 509 | ], 510 | "provenance": [], 511 | "include_colab_link": true 512 | }, 513 | "gpuClass": "standard", 514 | "kernelspec": { 515 | "display_name": "Python 3", 516 | "name": "python3" 517 | }, 518 | "language_info": { 519 | "name": "python" 520 | } 521 | }, 522 | "nbformat": 4, 523 | "nbformat_minor": 0 524 | } -------------------------------------------------------------------------------- /sd_diffusers_img2img_ui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "cellView": "form", 18 | "id": "x-z4MGy9rz9I" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "#@title #Install dependencies\n", 23 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi\n", 24 | "\n", 25 | "!pip install -r StableDiffusionUi/requirements.txt\n", 26 | "!wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n", 27 | "\n", 28 | "!mkdir -p outputs/{txt2img,img2img}\n", 29 | "!git clone https://huggingface.co/embed/negative /content/embeddings/negative\n", 30 | "!git clone https://huggingface.co/embed/lora /content/Lora/positive\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": 
"QRvNp3ElcaA3" 37 | }, 38 | "source": [ 39 | "# Creating model pipeline" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "dVlvLpShS9eo" 46 | }, 47 | "source": [ 48 | "\n", 49 | "
\n", 50 | "Description
TLDR: use dropdown or just paste
\n", 51 | "\n", 52 | "---\n", 53 | "\n", 54 | "You can just choose a model from a dropdown of popular models, which will also give you a link to a model page on [huggingface](https://huggingface.co) and any trigger words you need to use in the prompt.\n", 55 | "\n", 56 | "Alternatively, you can paste a model_id, url, or local path for any model from huggingface:\n", 57 | "\n", 58 | "* **model_id** looks like this: `gsdf/Counterfeit-V2.5`.\n", 59 | "
You can use it for any *proper* diffusers model in huggingface, i.e. ones that have a lot of folders in their **Files**, and a file called **model_index.json**\n", 60 | "
`A1111 format` toggle should be **off**.\n", 61 | "\n", 62 | "* **url** from HuggingFace may look like this:\n", 63 | "```\n", 64 | "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors\n", 65 | "```\n", 66 | "Use it when the model is in Automatic1111 format (as in, just a single file).\n", 67 | "
`✅ A1111 format` toggle should be **on**.\n", 68 | "\n", 69 | "* **url** from CivitAI should look like this:\n", 70 | "```\n", 71 | "https://civitai.com/api/download/models/XXXXXX\n", 72 | "```\n", 73 | "(the link you get from \"Download\" button)\n", 74 | "
`✅ A1111 format` toggle should be **on**.\n", 75 | "\n", 76 | "* **local path** is just a path to a model folder (if diffusers format) or model file (if A1111 format).\n", 77 | "
Set toggle accordingly.\n", 78 | "\n", 79 | "
" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": { 85 | "id": "dviyU3E6jK01" 86 | }, 87 | "source": [ 88 | "
\n", 89 | "Where to find urls\n", 90 | "\n", 91 | "A good place to find model urls is [camenduru colabs repo](https://https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n", 92 | "\n", 93 | "
" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "cellView": "form", 101 | "id": "k_TR5dUwznN9" 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "#@title Render model choice UI\n", 106 | "from StableDiffusionUi.ColabUI.ColabWrapper import ColabWrapper\n", 107 | "\n", 108 | "colab = ColabWrapper(\"outputs/img2img\")\n", 109 | "colab.render_model_index(\"/content/StableDiffusionUi/model_index.json\")" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "cellView": "form", 117 | "id": "IBdm3HvI02Yo" 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "#@title Load chosen model\n", 122 | "from diffusers import StableDiffusionImg2ImgPipeline\n", 123 | "colab.load_model(StableDiffusionImg2ImgPipeline)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "source": [ 129 | "#@title Render Settings UI\n", 130 | "colab.render_settings()" 131 | ], 132 | "metadata": { 133 | "cellView": "form", 134 | "id": "Yz3c-nayq431" 135 | }, 136 | "execution_count": null, 137 | "outputs": [] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": { 142 | "id": "hLmzZmRIZMpC" 143 | }, 144 | "source": [ 145 | "# Generating images" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "cellView": "form", 153 | "id": "oFruEVcl4ODt" 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "#@title Render UI\n", 158 | "#@markdown You don't need to run this cell again unless you want to change these settings\n", 159 | "save_images = True #@param {type:\"boolean\"}\n", 160 | "display_previewes = True #@param {type:\"boolean\"}\n", 161 | "#@markdown ---\n", 162 | "from StableDiffusionUi.ColabUI.DiffusionImg2ImgUI import DiffusionImg2ImgUI\n", 163 | "\n", 164 | "ui = DiffusionImg2ImgUI()\n", 165 | "colab.render_generation_ui(ui)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "cellView": "form", 173 | "id": "Y43YLjY7F7NI" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "#@title Run this cell to generate images\n", 178 | "colab.generate(save_images, display_previewes)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "id": "gthYZKknKsw7" 185 | }, 186 | "source": [ 187 | "#Utils" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "cellView": "form", 195 | "id": "9vn2Z5tNIAiv" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "#@markdown ## Image viewer\n", 200 | "#@markdown Run this cell to view last results in full size\n", 201 | "from IPython.display import clear_output\n", 202 | "import ipywidgets as widgets\n", 203 | "\n", 204 | "slider = widgets.IntSlider(max=len(results.images)-1)\n", 205 | "\n", 206 | "def handler(change):\n", 207 | " slider.max = len(results.images)-1\n", 208 | " if change.new > slider.max: change.new = slider.max\n", 209 | " clear_output(wait=True)\n", 210 | " display(slider, results.images[change.new])\n", 211 | "\n", 212 | "slider.observe(handler, names='value')\n", 213 | "\n", 214 | "display(slider, results.images[slider.value])" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": { 220 | "id": "kRUTJXoBc4tG" 221 | }, 222 | "source": [ 223 | "## Converting models from ckpt\n", 224 | "You shouldn't really need this, because diffusers now support loading from A1111 format.\n", 225 | "\n", 226 | "But if you want, you can download a 
model and convert it to diffusers format." 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "id": "av9urfOqi7Ak" 233 | }, 234 | "source": [ 235 | "NOTE: A good place to find model urls is [camenduru colabs repo](https://https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": { 242 | "cellView": "form", 243 | "id": "w1JZPBKId6Zv" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "#@markdown ##Download ckpt file\n", 248 | "#@markdown Download ckpt/safetensors file from huggingface url.\n", 249 | "\n", 250 | "#@markdown ---\n", 251 | "import os\n", 252 | "\n", 253 | "url = \"https://huggingface.co/mekabu/MagicalMix_v2/resolve/main/MagicalMix_v2.safetensors\" #@param {type:\"string\"}\n", 254 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n", 255 | "model_name = url.split('/')[-1]\n", 256 | "\n", 257 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n", 258 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n", 259 | "os.system(bashCommand)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "cellView": "form", 267 | "id": "dOn6tR_4eJ8k" 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "#@markdown ##(Optional) Or just upload file\n", 272 | "from google.colab import files\n", 273 | "f = files.upload()\n", 274 | "\n", 275 | "ckpt_dump_folder = \"/content/\"\n", 276 | "for key in f.keys():\n", 277 | " model_name = key\n", 278 | " break" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": { 285 | "cellView": "form", 286 | "id": "N1GOOJrej6tW" 287 | }, 288 | "outputs": [], 289 | "source": [ 290 | "#@markdown ##Run conversion script\n", 291 | "#@markdown Paste a path to where you want this script to cache a model into `dump_path`.\n", 292 | "\n", 293 | "#@markdown Keep `override_path` empty unless you uploaded your file to some custom directory.\n", 294 | "\n", 295 | "import os\n", 296 | "import torch\n", 297 | "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n", 298 | "\n", 299 | "from_safetensors = True #@param {type:\"boolean\"}\n", 300 | "override_path = \"\" #@param {type:\"string\"}\n", 301 | "if override_path != \"\":\n", 302 | " checkpoint_path = override_path\n", 303 | "else:\n", 304 | " checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n", 305 | "\n", 306 | "pipe = download_from_original_stable_diffusion_ckpt(\n", 307 | " checkpoint_path=checkpoint_path,\n", 308 | " original_config_file = \"/content/v1-inference.yaml\",\n", 309 | " from_safetensors=from_safetensors,\n", 310 | " )\n", 311 | "\n", 312 | "dump_path=\"models/ModelName/\" #@param {type:\"string\"}\n", 313 | "pipe.save_pretrained(dump_path, safe_serialization=from_safetensors)\n", 314 | "\n", 315 | "pipe = pipe.to(\"cuda\")\n", 316 | "pipe.safety_checker = None\n", 317 | "pipe.to(torch_dtype=torch.float16)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": { 323 | "id": "ii9UDlWQnIhg" 324 | }, 325 | "source": [ 326 | "## Converting VAE" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": { 333 | "cellView": "form", 334 | "id": "916cyCYPnNVt" 335 | }, 336 | "outputs": [], 337 | "source": [ 338 | "#@markdown ##Download vae file\n", 339 | 
"#@markdown Basically, the same thing as model files\n", 340 | "\n", 341 | "#@markdown ---\n", 342 | "import os\n", 343 | "\n", 344 | "url = \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\" #@param {type:\"string\"}\n", 345 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n", 346 | "model_name = url.split('/')[-1]\n", 347 | "\n", 348 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n", 349 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n", 350 | "os.system(bashCommand)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": { 357 | "cellView": "form", 358 | "id": "dWC8UpJ3m8Fa" 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "#@markdown Paste a path to where you want this script to cache a vae into `dump_path`.\n", 363 | "from_safetensors = False #@param {type:\"boolean\"}\n", 364 | "dump_path=\"vae/VaeName/\" #@param {type:\"string\"}\n", 365 | "\n", 366 | "checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n", 367 | "bashCommand = f\"python /content/StableDiffusionUi/scripts/convert_vae_pt_to_diffusers.py --vae_pt_path {checkpoint_path} --dump_path {dump_path}\"\n", 368 | "if from_safetensors:\n", 369 | " bashCommand += \" --from_safetensors\"\n", 370 | "os.system(bashCommand)" 371 | ] 372 | } 373 | ], 374 | "metadata": { 375 | "accelerator": "GPU", 376 | "colab": { 377 | "collapsed_sections": [ 378 | "hEwlf_SQXCyt", 379 | "mY0xKEW2lgmj", 380 | "lYurBnHCbZ3o", 381 | "kRUTJXoBc4tG", 382 | "ii9UDlWQnIhg" 383 | ], 384 | "provenance": [], 385 | "include_colab_link": true 386 | }, 387 | "gpuClass": "standard", 388 | "kernelspec": { 389 | "display_name": "Python 3", 390 | "name": "python3" 391 | }, 392 | "language_info": { 393 | "name": "python" 394 | } 395 | }, 396 | "nbformat": 4, 397 | "nbformat_minor": 0 398 | } -------------------------------------------------------------------------------- /sd_kandinsky_colab_ui.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "toc_visible": true, 8 | "authorship_tag": "ABX9TyNG2pl8jKc8G9gJx57zyOjk", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "view-in-github", 24 | "colab_type": "text" 25 | }, 26 | "source": [ 27 | "\"Open" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "source": [ 33 | "# Kandinsky" 34 | ], 35 | "metadata": { 36 | "id": "J5jPbraGhH-D" 37 | } 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "cellView": "form", 44 | "id": "ayR-stkCUiZh" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "#@title Install dependencies\n", 49 | "!pip install 'git+https://github.com/ai-forever/Kandinsky-2.git'\n", 50 | "!pip install git+https://github.com/openai/CLIP.git\n", 51 | "!mkdir -p outputs/{txt2img,img2img}\n", 52 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "source": [ 58 | "#@title Create model\n", 59 | "from kandinsky2 import get_kandinsky2\n", 60 | "\n", 61 | "model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', \n", 62 | " 
model_version='2.1', use_flash_attention=False)\n", 63 | "sampler = \"p_sampler\"" 64 | ], 65 | "metadata": { 66 | "cellView": "form", 67 | "id": "05ppJRuPfw56" 68 | }, 69 | "execution_count": null, 70 | "outputs": [] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "source": [ 75 | "#@title (Optional) Change sampler \n", 76 | "sampler = \"p_sampler\" #@param [\"p_sampler\", \"ddim_sampler\", \"plms_sampler\"]" 77 | ], 78 | "metadata": { 79 | "cellView": "form", 80 | "id": "OxmVn8OLf1al" 81 | }, 82 | "execution_count": null, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "source": [ 88 | "## Txt2Img" 89 | ], 90 | "metadata": { 91 | "id": "oy0VXcG8iZlj" 92 | } 93 | }, 94 | { 95 | "cell_type": "code", 96 | "source": [ 97 | "#@title Render UI\n", 98 | "#@markdown You don't need to run this cell again unless you want to change these settings\n", 99 | "save_images = True #@param {type:\"boolean\"}\n", 100 | "display_previews = True #@param {type:\"boolean\"}\n", 101 | "#@markdown ---\n", 102 | "\n", 103 | "from StableDiffusionUi.ColabUI.KandinskyUI import KandinskyUI\n", 104 | "\n", 105 | "ui = KandinskyUI()\n", 106 | "ui.render()" 107 | ], 108 | "metadata": { 109 | "cellView": "form", 110 | "id": "zFlSy38ef3iV" 111 | }, 112 | "execution_count": null, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "source": [ 118 | "#@title Run this cell to generate images\n", 119 | "images = ui.generate(model, sampler)\n", 120 | "\n", 121 | "if display_previews:\n", 122 | " ui.display_image_previews(images)\n", 123 | "\n", 124 | "if save_images:\n", 125 | " if \"output_index\" not in globals(): output_index = 0 # init the running file index once; it keeps counting across runs\n", 126 | " for image in images:\n", 127 | " image.save(f\"outputs/txt2img/{output_index:05}.png\"); print(f\"outputs/txt2img/{output_index:05}.png\")\n", 128 | " output_index += 1" 129 | ], 130 | "metadata": { 131 | "cellView": "form", 132 | "id": "z5dCbkQ0f9KV" 133 | }, 134 | "execution_count": null, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "source": [ 140 | "## Image mixing\n", 141 | "(WIP)" 142 | ], 143 | "metadata": { 144 | "id": "wQPK2SDpic70" 145 | } 146 | }, 147 | { 148 | "cell_type": "code", 149 | "source": [ 150 | "#@title Render UI\n", 151 | "import io\n", 152 | "import ipywidgets as widgets\n", 153 | "from PIL import Image\n", 154 | "\n", 155 | "text_img_a = widgets.Text(placeholder=\"Path to 1st image\")\n", 156 | "text_img_b = widgets.Text(placeholder=\"Path to 2nd image\")\n", 157 | "slider_img_a = widgets.FloatSlider(max=1, step=0.05)\n", 158 | "slider_img_b = widgets.FloatSlider(max=1, step=0.05)\n", 159 | "text_prompt_a = widgets.Text(placeholder=\"Optional 1st prompt\")\n", 160 | "text_prompt_b = widgets.Text(placeholder=\"Optional 2nd prompt\")\n", 161 | "slider_prompt_a = widgets.FloatSlider(max=1, step=0.05)\n", 162 | "slider_prompt_b = widgets.FloatSlider(max=1, step=0.05)\n", 163 | "\n", 164 | "width_field = widgets.IntText(value=512, layout=widgets.Layout(width='200px'), description=\"width: \")\n", 165 | "height_field = widgets.IntText(value=768, layout=widgets.Layout(width='200px'), description=\"height: \")\n", 166 | "steps_field = widgets.IntSlider(value=20, min=1, max=100, description=\"Steps: \")\n", 167 | "cfg_field = widgets.FloatSlider(value=7, min=1, max=20, step=0.5, description=\"CFG: \")\n", 168 | "\n", 169 | "display(widgets.HBox([text_img_a, text_img_b]), widgets.HBox([slider_img_a, slider_img_b]))\n", 170 | "display(widgets.HBox([text_prompt_a, text_prompt_b]), widgets.HBox([slider_prompt_a, slider_prompt_b]))\n", 171 | 
"display(widgets.HBox([width_field, height_field]), steps_field, cfg_field)" 172 | ], 173 | "metadata": { 174 | "cellView": "form", 175 | "id": "csc--queiiNE" 176 | }, 177 | "execution_count": null, 178 | "outputs": [] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "source": [ 183 | "#@title Run this cell to generate\n", 184 | "def display_image_previews(images):\n", 185 | " wgts = []\n", 186 | " for image in images:\n", 187 | " img_byte_arr = io.BytesIO()\n", 188 | " image.save(img_byte_arr, format='PNG')\n", 189 | " wgt = widgets.Image(\n", 190 | " value=img_byte_arr.getvalue(),\n", 191 | " width=image.width/2,\n", 192 | " height=image.height/2,\n", 193 | " )\n", 194 | " wgts.append(wgt)\n", 195 | " display(widgets.HBox(wgts))\n", 196 | "\n", 197 | "img_a = Image.open(text_img_a.value)\n", 198 | "img_b = Image.open(text_img_b.value)\n", 199 | "images_texts = [text_prompt_a.value, img_a, img_b, text_prompt_b.value]\n", 200 | "weights = [slider_prompt_a.value, slider_img_a.value, slider_img_b.value, slider_prompt_b.value]\n", 201 | "\n", 202 | "images = model.mix_images(\n", 203 | " images_texts, \n", 204 | " weights, \n", 205 | " num_steps=steps_field.value,\n", 206 | " batch_size=1, \n", 207 | " guidance_scale=cfg_field.value,\n", 208 | " h=height_field.value, w=width_field.value,\n", 209 | " sampler=sampler, \n", 210 | " prior_cf_scale=4,\n", 211 | " prior_steps=\"5\"\n", 212 | ")\n", 213 | "display_image_previews([img_a, images[0], img_b])" 214 | ], 215 | "metadata": { 216 | "cellView": "form", 217 | "id": "bz0Ty0b4nA2E" 218 | }, 219 | "execution_count": null, 220 | "outputs": [] 221 | } 222 | ] 223 | } -------------------------------------------------------------------------------- /utils/downloader.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import werkzeug 3 | import os 4 | from tqdm.notebook import tqdm 5 | 6 | def download_ckpt(ckpt_url, output_dir=None): 7 | """ 8 | Download the checkpoint and return a file name that can be used in a diffusers pipeline. 9 | Additional behaviour: 10 | - tries to get the file name from the response header or the url 11 | - downloads only if the file does not already exist 12 | - downloads with a progress bar 13 | """ 14 | HEADERS = {"User-Agent": "github/oneir0mancer/stable-diffusion-diffusers-colab-ui"} 15 | with requests.get(ckpt_url, headers=HEADERS, stream=True) as resp: 16 | 17 | # get file name 18 | MISSING_FILENAME = "missing_name" 19 | if content_disposition := resp.headers.get("Content-Disposition"): 20 | param, options = werkzeug.http.parse_options_header(content_disposition) 21 | if param == "attachment": 22 | filename = options.get("filename", MISSING_FILENAME) 23 | else: 24 | filename = MISSING_FILENAME 25 | else: 26 | filename = os.path.basename(ckpt_url) 27 | fileext = os.path.splitext(filename)[-1] 28 | if fileext == "": 29 | filename = MISSING_FILENAME 30 | 31 | if output_dir is not None: filename = os.path.join(output_dir, filename) 32 | 33 | # download file 34 | if not os.path.exists(filename): 35 | TOTAL_SIZE = int(resp.headers.get("Content-Length", 0)) 36 | CHUNK_SIZE = 50 * 1024**2 # 50 MiB or more as colab has high speed internet 37 | with open(filename, mode="wb") as file, tqdm(total=TOTAL_SIZE, desc="download checkpoint", unit="iB", unit_scale=True) as progress_bar: 38 | for chunk in resp.iter_content(chunk_size=CHUNK_SIZE): 39 | progress_bar.update(len(chunk)) 40 | file.write(chunk) 41 | return filename 42 | -------------------------------------------------------------------------------- /utils/empty_output.py:
-------------------------------------------------------------------------------- 1 | class EmptyOutput: 2 | """An empty wrapper around output, for when an Output widget wasn't provided""" 3 | def __enter__(self): pass 4 | 5 | def __exit__(self, type, value, traceback): pass 6 | 7 | def clear_output(self): pass -------------------------------------------------------------------------------- /utils/event.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | class Event: 4 | def __init__(self): 5 | self.__callbacks = [] 6 | 7 | def add_callback(self, callback: Callable): 8 | self.__callbacks.append(callback) 9 | 10 | def remove_callback(self, callback: Callable): 11 | self.__callbacks.remove(callback) 12 | 13 | def clear_callbacks(self): 14 | self.__callbacks.clear() 15 | 16 | def invoke(self, event_args: list = None): 17 | if event_args is None: 18 | for x in self.__callbacks: x() 19 | else: 20 | for x in self.__callbacks: x(event_args) 21 | 22 | @property 23 | def callbacks(self): return self.__callbacks 24 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL.PngImagePlugin import PngInfo 3 | 4 | def load_image_metadata(filepath: str): 5 | image = Image.open(filepath) 6 | prompt, etc = image.text["Data"].split("\nNegative: ") 7 | prompt = prompt.replace("\nPrompt: ", "") 8 | negative_prompt, etc = etc.split("\nCGF: ") 9 | # TODO parse other metadata into a dict, perhaps 10 | return prompt, negative_prompt 11 | 12 | def save_image_with_metadata(image, path, metadata_provider, additional_data = ""): 13 | meta = PngInfo() 14 | meta.add_text("Data", metadata_provider.metadata + additional_data) 15 | image.save(path, pnginfo=meta) 16 | -------------------------------------------------------------------------------- /utils/kohya_lora_loader.py: -------------------------------------------------------------------------------- 1 | #from: https://gist.github.com/takuma104/e38d683d72b1e448b8d9b3835f7cfa44 2 | import math 3 | import safetensors 4 | import torch 5 | from diffusers import DiffusionPipeline 6 | 7 | """ 8 | Kohya's LoRA format Loader for Diffusers 9 | Usage: 10 | ```py 11 | # A usual Diffusers setup 12 | import torch 13 | from diffusers import StableDiffusionPipeline 14 | pipe = StableDiffusionPipeline.from_pretrained('...', 15 | torch_dtype=torch.float16).to('cuda') 16 | # Import this module 17 | import kohya_lora_loader 18 | # Install the LoRA hook. This appends apply_lora and remove_lora methods to the pipe.
19 | kohya_lora_loader.install_lora_hook(pipe) 20 | # Load 'lora1.safetensors' file and apply 21 | lora1 = pipe.apply_lora('lora1.safetensors', 1.0) 22 | # You can change alpha 23 | lora1.alpha = 0.5 24 | # Load 'lora2.safetensors' file and apply 25 | lora2 = pipe.apply_lora('lora2.safetensors', 1.0) 26 | # Generate image with lora1 and lora2 applied 27 | pipe(...).images[0] 28 | # Remove lora2 29 | pipe.remove_lora(lora2) 30 | # Generate image with lora1 applied 31 | pipe(...).images[0] 32 | # Uninstall LoRA hook 33 | kohya_lora_loader.uninstall_lora_hook(pipe) 34 | # Generate image with no LoRA applied 35 | pipe(...).images[0] 36 | ``` 37 | """ 38 | 39 | 40 | # modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 41 | class LoRAModule(torch.nn.Module): 42 | def __init__( 43 | self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0 44 | ): 45 | """If alpha is 0 or None, alpha is set to the rank (no scaling).""" 46 | super().__init__() 47 | 48 | if org_module.__class__.__name__ == "Conv2d": 49 | in_dim = org_module.in_channels 50 | out_dim = org_module.out_channels 51 | else: 52 | in_dim = org_module.in_features 53 | out_dim = org_module.out_features 54 | 55 | self.lora_dim = lora_dim 56 | 57 | if org_module.__class__.__name__ == "Conv2d": 58 | kernel_size = org_module.kernel_size 59 | stride = org_module.stride 60 | padding = org_module.padding 61 | self.lora_down = torch.nn.Conv2d( 62 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False 63 | ) 64 | self.lora_up = torch.nn.Conv2d( 65 | self.lora_dim, out_dim, (1, 1), (1, 1), bias=False 66 | ) 67 | else: 68 | self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) 69 | self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) 70 | 71 | if alpha is None or alpha == 0: 72 | self.alpha = self.lora_dim 73 | else: 74 | if type(alpha) == torch.Tensor: 75 | alpha = alpha.detach().float().numpy() # without casting, bf16 causes error 76 | self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant.
77 | 78 | # same as microsoft's 79 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) 80 | torch.nn.init.zeros_(self.lora_up.weight) 81 | 82 | self.multiplier = multiplier 83 | 84 | def forward(self, x): 85 | scale = self.alpha / self.lora_dim 86 | return self.multiplier * scale * self.lora_up(self.lora_down(x)) 87 | 88 | 89 | class LoRAModuleContainer(torch.nn.Module): 90 | def __init__(self, hooks, state_dict, multiplier): 91 | super().__init__() 92 | self.multiplier = multiplier 93 | 94 | # Create LoRAModule from state_dict information 95 | for key, value in state_dict.items(): 96 | if "lora_down" in key: 97 | lora_name = key.split(".")[0] 98 | lora_dim = value.size()[0] 99 | lora_name_alpha = key.split(".")[0] + '.alpha' 100 | alpha = None 101 | if lora_name_alpha in state_dict: 102 | alpha = state_dict[lora_name_alpha].item() 103 | hook = hooks[lora_name] 104 | lora_module = LoRAModule( 105 | hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier 106 | ) 107 | self.register_module(lora_name, lora_module) 108 | 109 | # Load whole LoRA weights 110 | self.load_state_dict(state_dict) 111 | 112 | # Register LoRAModule to LoRAHook 113 | for name, module in self.named_modules(): 114 | if module.__class__.__name__ == "LoRAModule": 115 | hook = hooks[name] 116 | hook.append_lora(module) 117 | @property 118 | def alpha(self): 119 | return self.multiplier 120 | 121 | @alpha.setter 122 | def alpha(self, multiplier): 123 | self.multiplier = multiplier 124 | for name, module in self.named_modules(): 125 | if module.__class__.__name__ == "LoRAModule": 126 | module.multiplier = multiplier 127 | 128 | def remove_from_hooks(self, hooks): 129 | for name, module in self.named_modules(): 130 | if module.__class__.__name__ == "LoRAModule": 131 | hook = hooks[name] 132 | hook.remove_lora(module) 133 | del module 134 | 135 | 136 | class LoRAHook(torch.nn.Module): 137 | """ 138 | replaces forward method of the original Linear, 139 | instead of replacing the original Linear module. 
140 | """ 141 | 142 | def __init__(self): 143 | super().__init__() 144 | self.lora_modules = [] 145 | 146 | def install(self, orig_module): 147 | assert not hasattr(self, "orig_module") 148 | self.orig_module = orig_module 149 | self.orig_forward = self.orig_module.forward 150 | self.orig_module.forward = self.forward 151 | 152 | def uninstall(self): 153 | assert hasattr(self, "orig_module") 154 | self.orig_module.forward = self.orig_forward 155 | del self.orig_forward 156 | del self.orig_module 157 | 158 | def append_lora(self, lora_module): 159 | self.lora_modules.append(lora_module) 160 | 161 | def remove_lora(self, lora_module): 162 | self.lora_modules.remove(lora_module) 163 | 164 | def forward(self, x): 165 | if len(self.lora_modules) == 0: 166 | return self.orig_forward(x) 167 | lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0) 168 | return self.orig_forward(x) + lora 169 | 170 | 171 | class LoRAHookInjector(object): 172 | def __init__(self): 173 | super().__init__() 174 | self.hooks = {} 175 | self.device = None 176 | self.dtype = None 177 | 178 | def _get_target_modules(self, root_module, prefix, target_replace_modules): 179 | target_modules = [] 180 | for name, module in root_module.named_modules(): 181 | if ( 182 | module.__class__.__name__ in target_replace_modules 183 | and not "transformer_blocks" in name 184 | ): # to adapt latest diffusers: 185 | for child_name, child_module in module.named_modules(): 186 | is_linear = child_module.__class__.__name__ == "Linear" 187 | is_conv2d = child_module.__class__.__name__ == "Conv2d" 188 | if is_linear or is_conv2d: 189 | lora_name = prefix + "." + name + "." + child_name 190 | lora_name = lora_name.replace(".", "_") 191 | target_modules.append((lora_name, child_module)) 192 | return target_modules 193 | 194 | def install_hooks(self, pipe): 195 | """Install LoRAHook to the pipe.""" 196 | assert len(self.hooks) == 0 197 | text_encoder_targets = self._get_target_modules( 198 | pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"] 199 | ) 200 | unet_targets = self._get_target_modules( 201 | pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"] 202 | ) 203 | for name, target_module in text_encoder_targets + unet_targets: 204 | hook = LoRAHook() 205 | hook.install(target_module) 206 | self.hooks[name] = hook 207 | 208 | self.device = pipe.device 209 | self.dtype = pipe.unet.dtype 210 | 211 | def uninstall_hooks(self): 212 | """Uninstall LoRAHook from the pipe.""" 213 | for k, v in self.hooks.items(): 214 | v.uninstall() 215 | self.hooks = {} 216 | 217 | def apply_lora(self, filename, alpha=1.0): 218 | """Load LoRA weights and apply LoRA to the pipe.""" 219 | assert len(self.hooks) != 0 220 | state_dict = safetensors.torch.load_file(filename) 221 | container = LoRAModuleContainer(self.hooks, state_dict, alpha) 222 | container.to(self.device, self.dtype) 223 | return container 224 | 225 | def remove_lora(self, container): 226 | """Remove the individual LoRA from the pipe.""" 227 | container.remove_from_hooks(self.hooks) 228 | 229 | 230 | def install_lora_hook(pipe: DiffusionPipeline): 231 | """Install LoRAHook to the pipe.""" 232 | assert not hasattr(pipe, "lora_injector") 233 | assert not hasattr(pipe, "apply_lora") 234 | assert not hasattr(pipe, "remove_lora") 235 | injector = LoRAHookInjector() 236 | injector.install_hooks(pipe) 237 | pipe.lora_injector = injector 238 | pipe.apply_lora = injector.apply_lora 239 | pipe.remove_lora = injector.remove_lora 240 | 241 | 242 | def uninstall_lora_hook(pipe: 
DiffusionPipeline): 243 | """Uninstall LoRAHook from the pipe.""" 244 | pipe.lora_injector.uninstall_hooks() 245 | del pipe.lora_injector 246 | del pipe.apply_lora 247 | del pipe.remove_lora -------------------------------------------------------------------------------- /utils/markdown.py: -------------------------------------------------------------------------------- 1 | import markdown 2 | from ipywidgets import HTML 3 | 4 | def SpoilerLabel(warning:str, spoiler_text:str): 5 | md = f""" 6 | <details>
7 | <summary>{warning}</summary> 8 | {spoiler_text} 9 | </details>
10 | """ 11 | return HTML(markdown.markdown(md)) 12 | 13 | def MarkdownLabel(text:str): 14 | return HTML(markdown.markdown(text)) 15 | -------------------------------------------------------------------------------- /utils/preview.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from ipywidgets import Image, HBox, VBox, Button, Layout 4 | from typing import List 5 | 6 | def get_image_previews(paths: List[str], width: int, height: int, 7 | favourite_dir: str = "outputs/favourite") -> HBox: 8 | """Creates a widget preview for every image in paths""" 9 | wgts = [] 10 | for filepath in paths: 11 | preview = Image( 12 | value=open(filepath, "rb").read(), 13 | width=width, 14 | height=height, 15 | ) 16 | btn = _add_favourite_button(filepath, favourite_dir) 17 | wgt = VBox([preview, btn]) 18 | wgts.append(wgt) 19 | 20 | return HBox(wgts) 21 | 22 | def copy_file(file_path: str, to_path: str): 23 | """Copy file at `file_path` to `to_path` directory""" 24 | try_create_dir(to_path) 25 | filename = os.path.basename(file_path) 26 | shutil.copyfile(file_path, os.path.join(to_path, filename)) 27 | 28 | def try_create_dir(path: str): 29 | """Create folder at `path` if it doesn't already exist""" 30 | try: 31 | os.makedirs(path) 32 | except FileExistsError: pass 33 | 34 | def _add_favourite_button(filepath: str, to_dir:str) -> Button: 35 | btn = Button(description="Favourite", tooltip=filepath, layout=Layout(width="99%")) 36 | def btn_handler(b): 37 | copy_file(filepath, to_dir) 38 | btn.on_click(btn_handler) 39 | return btn -------------------------------------------------------------------------------- /utils/ui.py: -------------------------------------------------------------------------------- 1 | from ipywidgets import Button, Layout 2 | from IPython.display import display # explicit import, so render() also works outside IPython's injected globals 3 | class SwitchButton: 4 | def __init__(self, description_on_switch, description_off_switch, layout, 5 | callback_on, callback_off, start_on=False): 6 | self.description_on_switch = description_on_switch 7 | self.description_off_switch = description_off_switch 8 | self.callback_on = callback_on 9 | self.callback_off = callback_off 10 | self.state = start_on 11 | 12 | self.button = Button(description="", layout=layout) 13 | if start_on: self.__set_on_state() 14 | else: self.__set_off_state() 15 | 16 | def switch_on(self, b): 17 | if self.state is True: return 18 | self.callback_on() 19 | self.__set_on_state() 20 | 21 | def switch_off(self, b): 22 | if self.state is False: return 23 | self.callback_off() 24 | self.__set_off_state() 25 | 26 | def __set_on_state(self): 27 | self.button.on_click(self.switch_on, True) # second argument removes the previously attached handler 28 | self.button.on_click(self.switch_off) 29 | self.button.description = self.description_off_switch 30 | self.button.button_style = "success" 31 | self.state = True 32 | 33 | def __set_off_state(self): 34 | self.button.on_click(self.switch_off, True) # second argument removes the previously attached handler 35 | self.button.on_click(self.switch_on) 36 | self.button.description = self.description_on_switch 37 | self.button.button_style = "danger" 38 | self.state = False 39 | 40 | @property 41 | def render_element(self): 42 | return self.button 43 | 44 | def render(self): 45 | display(self.render_element) 46 | 47 | def _ipython_display_(self): 48 | self.render() 49 | 
--------------------------------------------------------------------------------
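For reference, here is a minimal usage sketch of `SwitchButton` from `utils/ui.py`. It assumes the repository is cloned as `StableDiffusionUi`, the way the notebooks clone it; the two callbacks are hypothetical placeholders, not part of the repo:

```py
from ipywidgets import Layout
from StableDiffusionUi.utils.ui import SwitchButton

# Hypothetical callbacks for illustration; any zero-argument callables work.
def on_enabled(): print("switched on")
def on_disabled(): print("switched off")

toggle = SwitchButton(
    description_on_switch="Enable",    # label shown while the button is in the off state
    description_off_switch="Disable",  # label shown while the button is in the on state
    layout=Layout(width="100px"),
    callback_on=on_enabled,            # fired on an off -> on click
    callback_off=on_disabled,          # fired on an on -> off click
    start_on=False,
)
toggle.render()  # or just evaluate `toggle` in a cell; _ipython_display_ calls render()
```

Each click flips `state`, swaps the click handler, and restyles the button (`success` while on, `danger` while off), so a callback only fires on an actual state change.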