├── .gitignore
├── ColabUI
│   ├── ArtistIndex.py
│   ├── BaseUI.py
│   ├── CanvasHolder.py
│   ├── ColabWrapper.py
│   ├── DiffusionImg2ImgUI.py
│   ├── DiffusionPipelineUI.py
│   ├── FlairView.py
│   ├── HugginfaceModelIndex.py
│   ├── ImageViewer.py
│   ├── Img2ImgRefinerUI.py
│   ├── InpaintRefinerUI.py
│   ├── KandinskyUI.py
│   ├── LoraApplyer.py
│   ├── LoraChoice.py
│   ├── LoraDownloader.py
│   ├── SamplerChoice.py
│   ├── SettingsTabs.py
│   ├── TextualInversionChoice.py
│   └── VaeChoice.py
├── LICENSE
├── README.md
├── artists.txt
├── docs
│   └── ui-example.jpg
├── lpw_stable_diffusion.py
├── model_index.json
├── requirements.txt
├── scripts
│   ├── convert_original_stable_diffusion_to_diffusers.py
│   └── convert_vae_pt_to_diffusers.py
├── sd_diffusers_colab_ui.ipynb
├── sd_diffusers_img2img_ui.ipynb
├── sd_kandinsky_colab_ui.ipynb
└── utils
    ├── downloader.py
    ├── empty_output.py
    ├── event.py
    ├── image_utils.py
    ├── kohya_lora_loader.py
    ├── markdown.py
    ├── preview.py
    └── ui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/ColabUI/ArtistIndex.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display, Javascript
2 | from ipywidgets import Dropdown, HTML, HBox, Layout, Text, Button, Accordion, Output
3 | import json
4 |
5 | class ArtistIndex:
6 | __clipboard: str
7 |
8 | def __init__(self, ui, index_path = "artist_index.json"):
9 | with open(index_path) as f:
10 | self.data = json.load(f)
11 |
12 | self.ui = ui
13 | self.example_link = HTML()
14 |
15 | self.__setup_artist_dropdown()
16 | self.__setup_buttons()
17 | self.__setup_foldout()
18 |
19 | def render(self):
20 | """Display ui"""
21 | self.__set_link_from_item(self.artist_dropdown.value)
22 | display(self.foldout, self.output)
23 |
24 | def __setup_artist_dropdown(self):
25 | self.artist_dropdown = Dropdown(
26 | options=[x for x in self.data],
27 | description="Artist:",
28 | )
29 | self.artist_dropdown.description_tooltip = "Choose artist flair tag to add"
30 | def dropdown_eventhandler(change):
31 | self.__set_link_from_item(change.new)
32 | self.artist_dropdown.observe(dropdown_eventhandler, names='value')
33 |
34 | def __setup_buttons(self):
35 | self.add_button = Button(description="Add", layout=Layout(width='80px'))
36 | def on_add_clicked(b):
37 | self.ui.positive_prompt.value += f", by {self.artist_dropdown.value}"
38 | self.add_button.on_click(on_add_clicked)
39 |
40 | self.copy_button = Button(description='Copy', tooltip="Copy to clipboard", layout=Layout(width='80px'))
41 | self.output = Output()
42 | def copy_event_handler(b):
43 | with self.output:
44 | text = self.__clipboard
45 | display(Javascript(f"navigator.clipboard.writeText('{text}');"))
46 | self.copy_button.on_click(copy_event_handler)
47 |
48 | def __setup_foldout(self):
49 | hbox = HBox([self.artist_dropdown, self.add_button, self.copy_button, self.example_link])
50 | self.foldout = Accordion(children=[hbox])
51 | self.foldout.set_title(0, "Flair tags")
52 | self.foldout.selected_index = None
53 |
54 | def __set_link_from_item(self, item):
55 | self.__clipboard = item
56 | key = item.lower().replace(" ", "-")
57 | self.example_link.value = f"Example: {item}"
58 |
59 | def _ipython_display_(self):
60 | self.render()
61 |
--------------------------------------------------------------------------------
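Usage sketch (not part of the repo): ArtistIndex only needs a ui object exposing a positive_prompt text widget, so any BaseUI subclass works. The index path and the package import paths below are assumptions.

    # Hypothetical notebook cell; assumes the repo root is importable as a package.
    from ColabUI.DiffusionPipelineUI import DiffusionPipelineUI
    from ColabUI.ArtistIndex import ArtistIndex

    ui = DiffusionPipelineUI()
    ui.render()
    artists = ArtistIndex(ui, index_path="artist_index.json")
    artists.render()   # "Add" appends ", by <artist>" to the positive prompt
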
/ColabUI/BaseUI.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display
2 | from ipywidgets import Textarea, IntText, IntSlider, FloatSlider, HBox, Layout, Label, Button
3 | from PIL.PngImagePlugin import PngInfo
4 | import ipywidgets as widgets
5 | import io
6 |
7 | class BaseUI:
8 | _metadata :str
9 |
10 | def __init__(self):
11 | self.positive_prompt = Textarea(placeholder='Positive prompt...', layout=Layout(width="50%"))
12 | self.negative_prompt = Textarea(placeholder='Negative prompt...', layout=Layout(width="50%"))
13 | self.width_field = IntText(value=512, layout=Layout(width='200px'), description="width: ")
14 | self.height_field = IntText(value=768, layout=Layout(width='200px'), description="height: ")
15 | self.steps_field = IntSlider(value=25, min=1, max=100, description="Steps: ")
16 | self.cfg_field = FloatSlider(value=7, min=1, max=20, step=0.5, description="CFG: ")
17 | self.seed_field = IntText(value=-1, description="seed: ")
18 | self.batch_field = IntText(value=1, layout=Layout(width='150px'), description="Batch size ")
19 | self.count_field = IntText(value=1, layout=Layout(width='150px'), description="Sample count ")
20 | self.clip_skip = None
21 | self.cfg_rescale = 0.0
22 | self.prompt_preprocessors = []
23 |
24 | def render(self):
25 | """Render UI widgets."""
26 | l1 = Label("Positive prompt:", layout=Layout(width="50%"))
27 | l2 = Label("Negative prompt:", layout=Layout(width="50%"))
28 | prompts = HBox([self.positive_prompt, self.negative_prompt])
29 |
30 | def swap_dims(b):
31 | self.width_field.value, self.height_field.value = self.height_field.value, self.width_field.value
32 |
33 | swap_btn = Button(description="⇆", layout=Layout(width='40px'))
34 | swap_btn.on_click(swap_dims)
35 |
36 | size_box = HBox([self.width_field, self.height_field, swap_btn])
37 | batch_box = HBox([self.batch_field, self.count_field])
38 |
39 | display(HBox([l1, l2]), prompts, size_box, self.steps_field, self.cfg_field, self.seed_field, batch_box)
40 |
41 | def get_size(self):
42 | """Get current (width, height) values"""
43 | return (self.width_field.value, self.height_field.value)
44 |
45 | def save_image_with_metadata(self, image, path, additional_data = ""):
46 | meta = PngInfo()
47 | meta.add_text("Data", self._metadata + additional_data)
48 | image.save(path, pnginfo=meta)
49 |
50 | def get_positive_prompt(self):
51 | return self.preprocess_prompt(self.positive_prompt.value)
52 |
53 | def get_negative_prompt(self):
54 | return self.preprocess_prompt(self.negative_prompt.value)
55 |
56 | def preprocess_prompt(self, prompt : str):
57 | for preprocessor in self.prompt_preprocessors:
58 | prompt = preprocessor.process(prompt)
59 | return prompt
60 |
61 | @property
62 | def metadata(self):
63 | return self._metadata
64 |
65 | @property
66 | def sample_count(self):
67 | return self.count_field.value
68 |
69 | def get_metadata_string(self):
70 | return f"\nPrompt: {self.positive_prompt.value}\nNegative: {self.negative_prompt.value}\nCGF: {self.cfg_field.value} Steps {self.steps_field.value} "
71 |
72 | def get_dict_to_cache(self):
73 | return {
74 | "prompt" : self.positive_prompt.value,
75 | "n_prompt" : self.negative_prompt.value,
76 | "CGF" : self.cfg_field.value,
77 | "steps" : self.steps_field.value,
78 | "h" : self.height_field.value,
79 | "w" : self.width_field.value,
80 | "batch" : self.batch_field.value,
81 | }
82 |
83 | def load_cache(self, cache):
84 | self.positive_prompt.value = cache["prompt"]
85 | self.negative_prompt.value = cache["n_prompt"]
86 |         self.cfg_field.value = cache["CFG"]
87 | self.steps_field.value = cache["steps"]
88 | self.height_field.value = cache["h"]
89 | self.width_field.value = cache["w"]
90 | self.batch_field.value = cache["batch"]
91 |
92 | def _ipython_display_(self):
93 | self.render()
94 |
--------------------------------------------------------------------------------
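The prompt_preprocessors hook is duck-typed: preprocess_prompt calls .process(prompt) on each registered object (FlairItem later in this repo is the in-tree implementation). A minimal sketch of a custom preprocessor; the class and its behavior are illustrative assumptions, as is the import path.

    from ColabUI.BaseUI import BaseUI   # import path is an assumption

    class CommaCleaner:
        """Hypothetical preprocessor: collapse doubled commas left by tag concatenation."""
        def process(self, prompt: str) -> str:
            return prompt.replace(",,", ",")

    ui = BaseUI()
    ui.prompt_preprocessors.append(CommaCleaner())
    ui.positive_prompt.value = "a castle,, by moonlight"
    print(ui.get_positive_prompt())   # -> "a castle, by moonlight"
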
/ColabUI/CanvasHolder.py:
--------------------------------------------------------------------------------
1 | from ipycanvas import MultiCanvas, hold_canvas
2 | from ipywidgets import ColorPicker, IntSlider, link, AppLayout, HBox, VBox, Button, Checkbox
3 | from ipywidgets import Image
4 | import PIL
5 | from IPython.display import display
6 | class DrawContext:
7 | def __init__(self):
8 | self.drawing = False
9 | self.position = None
10 | self.shape = []
11 |
12 | class CanvasHolder:
13 | def __init__(self, width=512, height=512, line_color="#749cb8"):
14 | self.multicanvas = MultiCanvas(2, width=width, height=height)
15 | self.canvas = self.multicanvas[1] #foreground
16 | self.background = self.multicanvas[0]
17 | self.context = DrawContext()
18 | self.save_file = "temp.png"
19 | self.scale_factor = 1
20 |
21 | self.canvas.on_mouse_down(self.on_mouse_down)
22 | self.canvas.on_mouse_move(self.on_mouse_move)
23 | self.canvas.on_mouse_up(self.on_mouse_up)
24 |
25 | self.picker = ColorPicker(description="Color:", value=line_color)
26 | link((self.picker, "value"), (self.canvas, "stroke_style"))
27 | link((self.picker, "value"), (self.canvas, "fill_style"))
28 | self.canvas.stroke_style = line_color
29 |
30 | self.line_width_slider = IntSlider(5, min=1, max=20, description="Line Width:")
31 | link((self.line_width_slider, "value"), (self.canvas, "line_width"))
32 |
33 | self.fill_check = Checkbox(value=False, description="Fill shape")
34 |
35 | self.clear_btn = Button(description="Clear")
36 | def clear_handler(b):
37 | self.set_dirty()
38 | with hold_canvas(): self.canvas.clear()
39 | self.clear_btn.on_click(clear_handler)
40 |
41 | self.clear_back_btn = Button(description="Clear Background")
42 | def clear_background_handler(b):
43 | with hold_canvas(): self.background.clear()
44 | self.clear_back_btn.on_click(clear_background_handler)
45 |
46 | self.save_btn = Button(description="Save")
47 | def save_handler(b):
48 | with hold_canvas(): self.save_mask()
49 | self.save_btn.on_click(save_handler)
50 |
51 | def save_to_file(*args, **kwargs):
52 | self.canvas.to_file(self.save_file)
53 | self.canvas.observe(save_to_file, "image_data")
54 |
55 | def on_mouse_down(self, x, y):
56 | self.set_dirty()
57 | self.context.drawing = True
58 | self.context.position = (x, y)
59 | self.context.shape = [self.context.position]
60 |
61 | def on_mouse_move(self, x, y):
62 | if not self.context.drawing: return
63 | self.context.shape.append((x, y))
64 | with hold_canvas():
65 | self.canvas.stroke_line(self.context.position[0], self.context.position[1], x, y)
66 | if self.canvas.line_width > 3: self.canvas.fill_circle(x, y, 0.5 * self.canvas.line_width)
67 | self.context.position = (x, y)
68 |
69 | def on_mouse_up(self, x, y):
70 | self.context.drawing = False
71 | with hold_canvas():
72 | self.canvas.stroke_line(self.context.position[0], self.context.position[1], x, y)
73 | if self.fill_check.value: self.canvas.fill_polygon(self.context.shape)
74 | self.context.shape = []
75 |
76 | def add_background(self, filepath, fit_image = False, fit_canvas = True, scale_factor = 1):
77 | with hold_canvas():
78 | self.background.clear()
79 | self.scale_factor = scale_factor
80 | img = Image.from_file(filepath) #Image widget
81 | img_size = PIL.Image.open(filepath).size
82 | if scale_factor != 1:
83 | img_size = (scale_factor * img_size[0], scale_factor * img_size[1])
84 |
85 | if fit_canvas:
86 | self.set_size(img_size[0], img_size[1])
87 |
88 | if fit_image:
89 | self.background.draw_image(img, 0,0, width=self.multicanvas.width, height=self.multicanvas.height)
90 | else:
91 | self.background.draw_image(img, 0,0, width=img_size[0], height=img_size[1])
92 |
93 | def save_mask(self, filepath = "temp.png"):
94 | self.save_file = filepath
95 | self.canvas.sync_image_data = True #optimization hack to not sync image all the time
96 | self.save_btn.button_style = "success"
97 |
98 | def set_dirty(self):
99 | self.canvas.sync_image_data = False
100 | self.save_btn.button_style = ""
101 |
102 | def set_size(self, width, height):
103 | self.set_dirty()
104 | self.multicanvas.width = width
105 | self.multicanvas.height = height
106 | # This stuff gets reset for some reason
107 | self.canvas.stroke_style = self.picker.value
108 | self.canvas.fill_style = self.picker.value
109 | self.canvas.line_width = self.line_width_slider.value
110 |
111 | @property
112 | def render_element(self):
113 | return HBox([self.multicanvas,
114 | VBox([self.picker, self.line_width_slider, self.fill_check,
115 | HBox([self.clear_btn, self.save_btn]),
116 | self.clear_back_btn
117 | ])
118 | ])
119 |
120 | def render(self):
121 | display(self.render_element)
122 |
123 | def _ipython_display_(self):
124 | self.render()
125 |
--------------------------------------------------------------------------------
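Usage sketch (assumptions: a notebook with ipycanvas installed and an existing image file; the filename is a placeholder). Saving is deliberately two-step: set_dirty() turns off image-data syncing while drawing for performance, and save_mask() re-enables it so the observer writes the PNG once the data has synced.

    from ColabUI.CanvasHolder import CanvasHolder   # import path is an assumption

    holder = CanvasHolder(width=512, height=512)
    holder.add_background("input.png", scale_factor=0.5)
    holder.render()                # draw the mask with the mouse
    holder.save_mask("mask.png")   # or press "Save"; the button turns green once written
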
/ColabUI/ColabWrapper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
4 | from diffusers import AutoencoderKL
5 | from .FlairView import FlairView
6 | from .HugginfaceModelIndex import HugginfaceModelIndex
7 | from .LoraDownloader import LoraDownloader
8 | from .LoraApplyer import LoraApplyer
9 | from .SettingsTabs import SettingsTabs
10 | from .Img2ImgRefinerUI import Img2ImgRefinerUI
11 | from ..utils.preview import get_image_previews, try_create_dir
12 | from ..utils.image_utils import save_image_with_metadata
13 | from IPython.display import display
14 | class ColabWrapper:
15 | def __init__(self, output_dir: str = "outputs/txt2img", img2img_dir: str = "outputs/img2img",
16 | favourite_dir: str = "outputs/favourite"):
17 | self.output_dir = output_dir
18 | self.img2img_dir = img2img_dir
19 | self.favourite_dir = favourite_dir
20 | self.output_index = 0
21 | self.cache = None
22 | self.lora_cache = dict()
23 | self.pipe = None
24 | self.custom_pipeline = None
25 |
26 | try_create_dir(output_dir)
27 |         try_create_dir(img2img_dir)
28 |         try_create_dir(favourite_dir)
29 | def render_model_index(self, filepath: str):
30 | self.model_index = HugginfaceModelIndex(filepath)
31 | self.model_index.render()
32 |
33 | def load_model(self, pipeline_interface, custom_pipeline:str = "lpw_stable_diffusion"):
34 | model_id, from_ckpt = self.model_index.get_model_id()
35 | loader_func = pipeline_interface.from_single_file if from_ckpt else pipeline_interface.from_pretrained
36 |
37 | self.custom_pipeline = custom_pipeline
38 |
39 | #TODO we can try catch here, and load file itself if diffusers doesn't want to load it for us
40 | self.pipe = loader_func(model_id,
41 | custom_pipeline=self.custom_pipeline,
42 | torch_dtype=torch.float16).to("cuda")
43 | self.pipe.safety_checker = None
44 | self.pipe.enable_xformers_memory_efficient_attention()
45 |
46 | def load_vae(self, id_or_path: str, subfolder: str):
47 | vae = AutoencoderKL.from_pretrained(id_or_path, subfolder=subfolder, torch_dtype=torch.float16).to("cuda")
48 | self.pipe.vae = vae
49 |
50 | def load_textual_inversion(self, path: str, filename: str):
51 | self.pipe.load_textual_inversion(path, weight_name=filename)
52 |
53 | def load_textual_inversions(self, root_folder: str):
54 | for path, subdirs, files in os.walk(root_folder):
55 | for name in files:
56 | try:
57 | ext = os.path.splitext(name)[1]
58 | if ext != ".pt" and ext != ".safetensors": continue
59 | self.pipe.load_textual_inversion(path, weight_name=name)
60 | print(path, name)
61 | except: pass
62 |
63 | def render_lora_loader(self, output_dir="Lora"):
64 |         self.lora_downloader = LoraDownloader(self, output_dir=output_dir, cache=self.lora_cache)
65 | self.lora_downloader.render()
66 |
67 | def render_lora_ui(self):
68 |         self.lora_ui = LoraApplyer(self, cache=self.lora_downloader.cache)
69 | self.lora_ui.render()
70 |
71 | def render_generation_ui(self, ui):
72 | self.ui = ui
73 | if (self.cache is not None): self.ui.load_cache(self.cache)
74 |
75 | self.settings = SettingsTabs(self)
76 |
77 | self.flair = FlairView(ui)
78 |
79 | self.settings.render()
80 | self.ui.render()
81 | self.flair.render()
82 |
83 | def restore_cached_ui(self):
84 | if (self.cache is not None): self.ui.load_cache(self.cache)
85 |
86 | def generate(self):
87 | self.cache = self.ui.get_dict_to_cache()
88 | paths = []
89 |
90 | for sample in range(self.ui.sample_count):
91 | results = self.ui.generate(self.pipe)
92 | for i, image in enumerate(results.images):
93 | path = os.path.join(self.output_dir, f"{self.output_index:05}.png")
94 | save_image_with_metadata(image, path, self.ui, f"Batch: {i}\n")
95 | self.output_index += 1
96 | paths.append(path)
97 |
98 | if self.settings.display_previews:
99 | display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))
100 |
101 | return paths
102 |
103 | #StableDiffusionXLImg2ImgPipeline
104 | def render_refiner_ui(self, pipeline_interface):
105 | components = self.pipe.components
106 | self.img2img_pipe = pipeline_interface(**components)
107 |
108 | self.refiner_ui = Img2ImgRefinerUI(self.ui)
109 | self.refiner_ui.render()
110 |
111 | def refiner_generate(self, output_dir: str = None):
112 | results = self.refiner_ui.generate(self.img2img_pipe)
113 | paths = []
114 |
115 | if output_dir is None: output_dir = self.img2img_dir
116 | for i, image in enumerate(results.images):
117 | path = os.path.join(output_dir, f"{self.output_index:05}.png")
118 | save_image_with_metadata(image, path, self.refiner_ui, f"Batch: {i}\n")
119 | self.output_index += 1
120 | paths.append(path)
121 |
122 | if self.settings.display_previews:
123 | display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))
124 | return paths
125 |
126 | def render_inpaint_ui(self, pipeline_interface, ui):
127 | components = self.pipe.components
128 | self.inpaint_pipe = pipeline_interface(**components)
129 |
130 | self.inpaint_ui = ui
131 | self.inpaint_ui.render()
132 |
133 | def inpaint_generate(self, output_dir: str = None):
134 | results = self.inpaint_ui.generate(self.inpaint_pipe)
135 | paths = []
136 |
137 | if output_dir is None: output_dir = self.img2img_dir
138 | for i, image in enumerate(results.images):
139 | path = os.path.join(output_dir, f"{self.output_index:05}.png")
140 | save_image_with_metadata(image, path, self.inpaint_ui, f"Batch: {i}\n")
141 | self.output_index += 1
142 | paths.append(path)
143 |
144 | if self.settings.display_previews:
145 | display(get_image_previews(paths, 512, 512, favourite_dir=self.favourite_dir))
146 | return paths
147 |
--------------------------------------------------------------------------------
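Sketch of the intended cell-by-cell notebook flow. The concrete pipeline classes and import paths are assumptions; any diffusers pipeline exposing from_pretrained/from_single_file should fit.

    from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline
    from ColabUI.ColabWrapper import ColabWrapper
    from ColabUI.DiffusionPipelineUI import DiffusionPipelineUI

    colab = ColabWrapper()
    colab.render_model_index("model_index.json")       # cell 1: pick or paste a model
    colab.load_model(DiffusionPipeline)                # cell 2: load to CUDA in fp16
    colab.render_generation_ui(DiffusionPipelineUI())  # cell 3: settings + prompt widgets
    paths = colab.generate()                           # cell 4: txt2img, returns saved PNG paths

    colab.render_refiner_ui(StableDiffusionImg2ImgPipeline)  # optional img2img refine pass
    refined = colab.refiner_generate()
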
/ColabUI/DiffusionImg2ImgUI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from IPython.display import display
3 | from ipywidgets import FloatSlider, Dropdown, HBox, Text, Button, FileUpload
4 | from PIL import Image
5 | import ipywidgets as widgets
6 | import io
7 | from .BaseUI import BaseUI
8 |
9 | class DiffusionImg2ImgUI(BaseUI):
10 | _current_file :str
11 |     _current_image :Image.Image
12 |
13 | def __init__(self):
14 | super().__init__()
15 | self.__generator = torch.Generator(device="cuda")
16 | self.strength_field = FloatSlider(value=0.75, min=0, max=1, step=0.05, description="Strength: ")
17 | self.path_field = Text(description="Url:", placeholder="Image path...")
18 | self.upload_button = Button(description='Choose')
19 | self.image_preview = widgets.Image()
20 |
21 | def uploader_eventhandler(b):
22 | self._current_file = self.path_field.value
23 | temp = widgets.Image().from_file(self._current_file)
24 | self.image_preview.value = temp.value
25 | self._current_image = Image.open(self._current_file)
26 | self.set_size(self._current_image)
27 | self.upload_button.on_click(uploader_eventhandler)
28 |
29 | def render(self):
30 | super().render()
31 | display(self.strength_field, HBox([self.path_field, self.upload_button]), self.image_preview)
32 |
33 | def set_size(self, image, scaling_factor=1):
34 | (self.width_field.value, self.height_field.value) = tuple(scaling_factor * x for x in image.size)
35 |
36 | def generate(self, pipe, init_image=None, generator=None):
37 | """Generate images given DiffusionPipeline, and settings set in UI."""
38 | if self.seed_field.value >= 0:
39 | seed = self.seed_field.value
40 | else:
41 | seed = self.__generator.seed()
42 |
43 | if init_image is None: init_image = self._current_image
44 |
45 | g = torch.cuda.manual_seed(seed)
46 | self._metadata = self.get_metadata_string() + f"Seed: {seed} "
47 |
48 | init_image = init_image.convert('RGB')
49 | init_image = init_image.resize((self.width_field.value, self.height_field.value), resample=Image.BILINEAR)
50 |
51 | results = pipe(image=init_image,
52 | prompt=self.positive_prompt.value,
53 | negative_prompt=self.negative_prompt.value,
54 | num_inference_steps=self.steps_field.value,
55 | num_images_per_prompt = self.batch_field.value,
56 | guidance_scale=self.cfg_field.value,
57 | strength=self.strength_field.value,
58 | generator=g)
59 | return results
60 |
61 | def get_dict_to_cache(self):
62 | cache = super().get_dict_to_cache()
63 | cache["image_path"] = self.path_field.value
64 | return cache
65 |
66 | def load_cache(self, cache):
67 | super().load_cache(cache)
68 | self.path_field.value = cache["image_path"]
69 |
--------------------------------------------------------------------------------
/ColabUI/DiffusionPipelineUI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .BaseUI import BaseUI
3 |
4 | class DiffusionPipelineUI(BaseUI):
5 | def __init__(self):
6 | super().__init__()
7 | self.__generator = torch.Generator(device="cuda")
8 |
9 | def generate(self, pipe, generator = None):
10 | """Generate images given DiffusionPipeline, and settings set in UI."""
11 | if self.seed_field.value >= 0:
12 | seed = self.seed_field.value
13 | else:
14 | seed = self.__generator.seed()
15 |
16 | g = torch.cuda.manual_seed(seed)
17 | self._metadata = self.get_metadata_string() + f"Seed: {seed} "
18 |
19 | results = pipe(prompt=self.get_positive_prompt(),
20 |                        negative_prompt=self.get_negative_prompt(),
21 | num_inference_steps=self.steps_field.value,
22 | num_images_per_prompt = self.batch_field.value,
23 | guidance_scale=self.cfg_field.value,
24 | guidance_rescale=self.cfg_rescale,
25 | generator=g, clip_skip=self.clip_skip,
26 | height=self.height_field.value, width=self.width_field.value)
27 | return results
28 |
--------------------------------------------------------------------------------
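The UI also works without ColabWrapper; generate() only needs a callable pipeline. A standalone sketch (the model id, GPU availability, and import path are assumptions):

    import torch
    from diffusers import StableDiffusionPipeline
    from ColabUI.DiffusionPipelineUI import DiffusionPipelineUI

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    ui = DiffusionPipelineUI()
    ui.render()                   # set prompts/steps/CFG in the widgets
    results = ui.generate(pipe)   # seed -1 draws a fresh random seed per call
    results.images[0].save("out.png")
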
/ColabUI/FlairView.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from IPython.display import display, Javascript
4 | from ipywidgets import Output, Text, Layout, Button, HBox, VBox, Dropdown, Accordion, Checkbox
5 | from ..utils.empty_output import EmptyOutput
6 | from ..utils.downloader import download_ckpt
7 | from ..utils.ui import SwitchButton
8 |
9 | class FlairItem:
10 | __data : list[str]
11 |
12 | def __init__(self, ui, out:Output = None, index : int = -1):
13 | self.ui = ui
14 | if out is None: out = EmptyOutput()
15 | self.out = out
16 | self.clipboard_output = Output()
17 | self.index = index
18 |
19 | self.url_text = Text(placeholder="Path or url to load", layout=Layout(width="40%"))
20 | if index >= 0: self.url_text.description = f"{index + 1}:"
21 | self.url_text.description_tooltip = "Path to a file with a list of items, or a url"
22 |
23 | self.load_button = Button(description="Load", layout=Layout(width='50px'))
24 | def load_btn_click(b):
25 | self.out.clear_output()
26 | with self.out:
27 | self.load_data(self.url_text.value)
28 | if self.__data is not None and len(self.__data) > 0:
29 | self.switch_to_dropdown()
30 | self.load_button.on_click(load_btn_click)
31 |
32 | self.load_view = HBox([self.url_text, self.load_button])
33 |
34 | self.flair_dropdown = Dropdown(layout=Layout(width="40%"))
35 | if index >= 0: self.flair_dropdown.description = f"{index + 1}:"
36 | self.add_button = Button(description="Add", layout=Layout(width='80px'))
37 | def add_btn_click(b):
38 | self.ui.positive_prompt.value += f", by {self.flair_dropdown.value}"
39 | self.add_button.on_click(add_btn_click)
40 |
41 | self.copy_button = Button(description="Copy", tooltip="Copy to clipboard", layout=Layout(width='80px'))
42 | def copy_event_handler(b):
43 | with self.clipboard_output:
44 | text = self.flair_dropdown.value
45 | display(Javascript(f"navigator.clipboard.writeText('{text}');"))
46 | self.copy_button.on_click(copy_event_handler)
47 |
48 | self.preprocess_btn = SwitchButton("Preprocess", "Preprocess", Layout(width='100px'),
49 | self.preprocessor_on, self.preprocessor_off)
50 |         self.preprocess_btn.button.tooltip = f"Automatically substitute {{{index + 1}}} for an element of this list in prompt"
51 | self.subsitute_index = 0
52 | self.subsitute_to_random = Checkbox(value=False, description="Random", indent=False)
53 |
54 | self.dropdown_view = HBox([self.flair_dropdown, self.add_button, self.copy_button,
55 | self.preprocess_btn.render_element, self.subsitute_to_random])
56 | self.dropdown_view.layout.visibility = "hidden"
57 | self.dropdown_view.layout.height = "0px"
58 |
59 | def switch_to_dropdown(self):
60 | self.load_view.layout.visibility = "hidden"
61 | self.dropdown_view.layout.visibility = "visible"
62 | self.dropdown_view.layout.height = self.load_view.layout.height
63 | self.load_view.layout.height = "0px"
64 | self.flair_dropdown.options = self.__data
65 | self.flair_dropdown.description_tooltip = self.url_text.value
66 |
67 | @property
68 | def render_element(self):
69 | return VBox([self.load_view, self.dropdown_view, self.clipboard_output])
70 |
71 | def render(self):
72 | display(self.render_element)
73 |
74 | def _ipython_display_(self):
75 | self.render()
76 |
77 | def preprocessor_on(self):
78 | self.ui.prompt_preprocessors.append(self)
79 |
80 | def preprocessor_off(self):
81 | self.ui.prompt_preprocessors.remove(self)
82 |
83 | def process(self, prompt:str):
84 | #TODO guard index
85 | item = self.__data[self.subsitute_index]
86 | self.step_subsitute_index()
87 | return prompt.replace(f"{{{self.index + 1}}}", item)
88 |
89 | def set_subsitute_index(self, index:int = 0):
90 | self.subsitute_index = index
91 |
92 | def step_subsitute_index(self):
93 | if self.subsitute_to_random.value:
94 |             self.subsitute_index = random.randrange(len(self.__data))
95 | else:
96 | self.subsitute_index += 1
97 | if self.subsitute_index >= len(self.__data): self.subsitute_index = 0
98 |
99 | def load_data(self, path:str):
100 | if os.path.isfile(path):
101 | self.__data = self.read_file(path)
102 | else:
103 | f = download_ckpt(path)
104 | self.__data = self.read_file(f)
105 |
106 | @staticmethod
107 | def read_file(path:str):
108 | with open(path) as f: data = f.readlines()
109 | data = [x.strip('\n') for x in data]
110 | return data
111 |
112 | class FlairView:
113 | __items : list[FlairItem]
114 |
115 | def __init__(self, ui):
116 | self.ui = ui
117 | self.vbox = VBox([])
118 |
119 | self.output = Output(layout={'border': '1px solid black'})
120 | self.clear_output_button = Button(description="Clear Log")
121 | def on_clear_clicked(b):
122 | self.output.clear_output()
123 | self.clear_output_button.on_click(on_clear_clicked)
124 | right_align_layout = Layout(display='flex', flex_flow='column', align_items='flex-end', width='99%')
125 |
126 | self.add_button = Button(description="Add", layout=Layout(width='80px'), button_style='info')
127 | def add_btn_click(b):
128 | f = FlairItem(self.ui, self.output, len(self.vbox.children))
129 | self.vbox.children=tuple(list(self.vbox.children) + [f.render_element])
130 | self.add_button.on_click(add_btn_click)
131 |
132 | self.foldout = Accordion(children=[VBox([self.vbox, self.add_button, self.output,
133 | HBox(children=[self.clear_output_button],layout=right_align_layout)])])
134 | self.foldout.set_title(0, "Flair tags")
135 | self.foldout.selected_index = None
136 |
137 | @property
138 | def render_element(self):
139 | return self.foldout
140 |
141 | def render(self):
142 | display(self.render_element)
143 |
144 | def _ipython_display_(self):
145 | self.render()
146 |
--------------------------------------------------------------------------------
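When a FlairItem's "Preprocess" switch is on, the item registers itself in ui.prompt_preprocessors, and process() rewrites "{N}" placeholders (N is the item's 1-based index) with the current list entry, stepping sequentially or at random per generation. An illustrative trace of the substitution, with a hypothetical two-entry list on item 1:

    template = "castle, {1} style"
    data = ["oil painting", "watercolor"]   # hypothetical list loaded into item 1
    for i in range(3):
        print(template.replace("{1}", data[i % len(data)]))
    # -> oil painting, then watercolor, then oil painting again (the index wraps)
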
/ColabUI/HugginfaceModelIndex.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display
2 | from ipywidgets import Dropdown, HTML, HBox, Layout, Text, Checkbox
3 | import json
4 | import requests
5 | import werkzeug
6 | import os
7 | from tqdm.notebook import tqdm
8 | from ..utils.downloader import download_ckpt
9 |
10 | class HugginfaceModelIndex:
11 | def __init__(self, filepath = "model_index.json"):
12 | with open(filepath) as f:
13 | self.data = json.load(f)
14 |
15 | self.model_link = HTML()
16 | self.from_ckpt = Checkbox(value=False, description="A1111 format")
17 |
18 | self.__setup_url_field()
19 | self.__setup_model_dropdown()
20 |
21 | def render(self):
22 | """Display ui"""
23 | self.__set_link_from_dict(self.model_dropdown.value)
24 | display(self.model_dropdown, HBox([self.url_text, self.from_ckpt]), self.model_link)
25 |
26 | def get_model_id(self):
27 | """Return model_id/url/local path of the model, and whether it should be loaded with from_ckpt"""
28 | if self.url_text.value != "":
29 | return self.__handle_url_value(self.url_text.value)
30 | return self.data[self.model_dropdown.value]["id"], False
31 |
32 | def __setup_model_dropdown(self):
33 | self.model_dropdown = Dropdown(
34 | options=[x for x in self.data],
35 | description=self.highlight("Model:"),
36 | )
37 | def dropdown_eventhandler(change):
38 | self.__set_link_from_dict(change.new)
39 | self.model_dropdown.observe(dropdown_eventhandler, names='value')
40 | self.model_dropdown.description_tooltip = "Choose model from model index"
41 |
42 | def __setup_url_field(self):
43 | self.url_text = Text(description="Url:", placeholder='Optional url or path...', layout=Layout(width="50%"))
44 | def url_eventhandler(change):
45 | self.__set_link_from_url(change.new)
46 | self.url_text.observe(url_eventhandler, names='value')
47 | self.url_text.description_tooltip = "Model_id, url, or local path for any other model not in index"
48 |
49 | def __set_link_from_dict(self, key):
50 | self.from_ckpt.value = False
51 | self.from_ckpt.disabled = True
52 | self.url_text.value = ""
53 |         self.model_link.value = f"Model info: {key}"
54 |         try:
55 |             self.model_link.value += f"<br>Trigger prompt: {self.data[key]['trigger']}"
56 |         except: pass
57 |
58 | def __set_link_from_url(self, new_url):
59 | if new_url == "":
60 | self.__set_link_from_dict(self.model_dropdown.value)
61 | self.url_text.description = "Url:"
62 | self.model_dropdown.description = self.highlight("Model:")
63 | self.from_ckpt.value = False
64 | self.from_ckpt.disabled = True
65 | else:
66 |             self.model_link.value = f"Model info: <a href='{new_url}'>link</a>"
67 | self.url_text.description = self.highlight("Url:")
68 | self.model_dropdown.description = "Model:"
69 | if new_url.startswith("https://civitai.com/api/download/models/") or new_url.endswith(".safetensors"):
70 | self.from_ckpt.value = True
71 | elif self.is_huggingface_model_id(new_url):
72 | self.from_ckpt.value = False
73 | elif os.path.exists(new_url):
74 | self.from_ckpt.value = os.path.isfile(new_url)
75 | self.from_ckpt.disabled = False
76 |
77 | def __handle_url_value(self, url):
78 | if self.is_huggingface_model_id(url):
79 | return url, False
80 | elif os.path.exists(url):
81 | return url, self.from_ckpt.value
82 | else:
83 | return download_ckpt(url), self.from_ckpt.value
84 |
85 | @staticmethod
86 | def highlight(str_to_highlight):
87 | return f"{str_to_highlight}"
88 |
89 | @staticmethod
90 | def is_huggingface_model_id(url):
91 | return len(url.split('/')) == 2
92 |
93 | def _ipython_display_(self):
94 | self.render()
95 |
--------------------------------------------------------------------------------
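Only two fields per entry are read from model_index.json: a required "id" (Huggingface model id for dropdown entries) and an optional "trigger" shown next to the model info. A plausible shape, sketched as a Python dict; the concrete entries are assumptions:

    # Shape HugginfaceModelIndex expects (entries are hypothetical):
    model_index = {
        "Stable Diffusion 1.5": {"id": "runwayml/stable-diffusion-v1-5"},
        "Some finetune": {"id": "author/model", "trigger": "trigger_word"},
    }
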
/ColabUI/ImageViewer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import isfile, join
3 | from ipywidgets import Dropdown, Button, Image, VBox, HBox
4 | from IPython.display import display
5 | class ImageViewer:
6 | def __init__(self, path: str = "outputs/favourite", width: int = 512, height: int = 512):
7 | self.path = path
8 | self.files = sorted([f for f in os.listdir(path) if isfile(join(path, f))])
9 | self.dropdown = Dropdown(
10 | options=[x for x in self.files],
11 | description="Image:",
12 | )
13 | self.preview = Image(
14 | width=width,
15 | height=height
16 | )
17 | if len(self.files) > 0: self.preview.value = open(join(self.path, self.files[0]), "rb").read()
18 |
19 | def dropdown_eventhandler(change):
20 | self.preview.value=open(join(self.path, change.new), "rb").read()
21 | self.dropdown.observe(dropdown_eventhandler, names='value')
22 |
23 | self.btn_refresh = Button(description="Refresh")
24 | def btn_handler(b):
25 | self.refresh()
26 | self.btn_refresh.on_click(btn_handler)
27 |
28 | def refresh(self):
29 | value = self.dropdown.value
30 | self.files = sorted([f for f in os.listdir(self.path) if isfile(join(self.path, f))])
31 | self.dropdown.options = [x for x in self.files]
32 | self.dropdown.value = value
33 |
34 | @property
35 | def render_element(self):
36 | return VBox([HBox([self.dropdown, self.btn_refresh]), self.preview])
37 |
38 | def render(self):
39 | display(self.render_element)
40 |
41 | def _ipython_display_(self):
42 | self.render()
43 |
--------------------------------------------------------------------------------
/ColabUI/Img2ImgRefinerUI.py:
--------------------------------------------------------------------------------
1 | import io
2 | import ipywidgets as widgets #PIL.Image vs ipywidgets.Image
3 | import numpy as np
4 | import os
5 | import torch
6 | from IPython.display import display
7 | from ipywidgets import FloatSlider, Dropdown, HBox, VBox, Text, Button, Label
8 | from PIL import Image
9 | from .BaseUI import BaseUI
10 | from ..utils.image_utils import load_image_metadata
11 |
12 | class Img2ImgRefinerUI:
13 | def __init__(self, ui: BaseUI):
14 | self.__generator = torch.Generator(device="cuda")
15 | self.__base_ui = ui
16 | self.__current_file = ""
17 | self.__current_img = None
18 | self.__preview_visible = True
19 |
20 | self.strength_field = FloatSlider(value=0.35, min=0, max=1, step=0.05, description="Strength: ")
21 | self.upscale_field = FloatSlider(value=1, min=1, max=4, step=0.25, description="Upscale factor: ")
22 | self.path_field = Text(description="Url:", placeholder="Image path...")
23 | self.load_button = Button(description="Choose")
24 | self.load_prompts_button = Button(description="Load prompts")
25 | self.image_preview = widgets.Image(width=300, height=400)
26 | self.hide_button = Button(description="Hide")
27 | self.size_label = Label(value="Size")
28 |
29 | def load_handler(b):
30 | if not os.path.isfile(self.path_field.value): return
31 | self.__current_file = self.path_field.value
32 | self.image_preview.value = open(self.__current_file, "rb").read()
33 | self.__current_img = Image.open(self.__current_file)
34 | self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
35 | self.load_button.on_click(load_handler)
36 |
37 | def load_prompts_handler(b):
38 | if not os.path.isfile(self.path_field.value): return
39 | prompt, negative = load_image_metadata(self.path_field.value)
40 | self.__base_ui.positive_prompt.value = prompt
41 | self.__base_ui.negative_prompt.value = negative
42 | self.load_prompts_button.on_click(load_prompts_handler)
43 |
44 | def hide_handler(b):
45 | self.__preview_visible = not self.__preview_visible
46 | self.image_preview.layout.width = "300px" if self.__preview_visible else "0px"
47 | self.hide_button.description = "Hide" if self.__preview_visible else "Show preview"
48 | self.hide_button.on_click(hide_handler)
49 |
50 | def on_slider_change(change):
51 | if self.__current_img is None: return
52 | self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
53 | self.upscale_field.observe(on_slider_change, 'value')
54 |
55 | def generate(self, pipe, init_image=None, generator=None):
56 | """Generate images given Img2Img Pipeline, and settings set in Base UI and Refiner UI."""
57 | if self.__base_ui.seed_field.value >= 0:
58 | seed = self.__base_ui.seed_field.value
59 | else:
60 | seed = self.__generator.seed()
61 |
62 | if init_image is None:
63 | init_image = self.__current_img
64 |
65 | init_image = init_image.convert('RGB')
66 | size = self.get_upscaled_size(init_image.size)
67 | init_image = init_image.resize(size, resample=Image.LANCZOS)
68 |
69 | g = torch.cuda.manual_seed(seed)
70 | self._metadata = self.__base_ui.get_metadata_string() + f"\nImg2Img Seed: {seed}, Noise Strength: {self.strength_field.value}, Upscale: {self.upscale_field.value} "
71 |
72 | results = pipe(image=init_image,
73 | prompt=self.__base_ui.positive_prompt.value,
74 | negative_prompt=self.__base_ui.negative_prompt.value,
75 | num_inference_steps=self.__base_ui.steps_field.value,
76 | num_images_per_prompt = self.__base_ui.batch_field.value,
77 | guidance_scale=self.__base_ui.cfg_field.value,
78 | guidance_rescale=self.__base_ui.cfg_rescale,
79 | strength=self.strength_field.value,
80 | width=size[0], height=size[1],
81 | generator=g)
82 | return results
83 |
84 | @property
85 | def metadata(self):
86 | return self._metadata
87 |
88 | def get_upscaled_size(self, size):
89 | return (int(self.upscale_field.value * size[0]), int(self.upscale_field.value * size[1]))
90 |
91 | def __create_color_picker(self):
92 | picker = widgets.ColorPicker(description='Pick a color', value='#000000', concise=False)
93 | btn = Button(description = "Apply")
94 | def load_color_handler(b):
95 | self.__current_img = self.__create_image_from_color(picker)
96 | self.size_label.value = f"{self.__current_img.size} → {self.get_upscaled_size(self.__current_img.size)}"
97 | f = io.BytesIO()
98 | self.__current_img.save(f, "png")
99 | self.image_preview.value = f.getvalue()
100 | btn.on_click(load_color_handler)
101 |
102 | foldout = widgets.Accordion(children=[HBox([picker, btn])])
103 | foldout.set_title(0, "From solid color")
104 | foldout.selected_index = None
105 | return foldout
106 |
107 | def __create_image_from_color(self, picker):
108 | h = picker.value.lstrip('#')
109 | rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
110 | width = int(self.__base_ui.width_field.value)
111 | height = int(self.__base_ui.height_field.value)
112 | x = np.tile(rgb, (height, width, 1)).astype(np.uint8)
113 | return Image.fromarray(x)
114 |
115 | @property
116 | def render_element(self):
117 | return VBox([self.strength_field, self.upscale_field,
118 | HBox([self.path_field, self.load_button, self.load_prompts_button]),
119 | self.__create_color_picker(),
120 | self.size_label, self.image_preview, self.hide_button])
121 |
122 | def render(self):
123 | display(self.render_element)
124 |
125 | def _ipython_display_(self):
126 | self.render()
127 |
--------------------------------------------------------------------------------
/ColabUI/InpaintRefinerUI.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import torch
4 | import numpy as np
5 | from ipywidgets import Text, FloatSlider, HBox, VBox, Button, Accordion
6 | from PIL import Image
7 | from .CanvasHolder import CanvasHolder
8 | from IPython.display import display
9 | #TODO can blur: https://huggingface.co/docs/diffusers/using-diffusers/inpaint
10 |
11 | class InpaintRefinerUI:
12 | def __init__(self, ui):
13 | self.__generator = torch.Generator(device="cuda")
14 | self.__base_ui = ui
15 | self.__current_file = ""
16 | self.scale_factor = 0.5
17 |
18 | self.canvas_holder = CanvasHolder(200, 200)
19 | self.strength_field = FloatSlider(value=0.35, min=0, max=1, step=0.05, description="Strength: ")
20 | self.path_field = Text(description="Url:", placeholder="Image path...")
21 | self.load_button = Button(description="Choose")
22 | self.load_prompts_button = Button(description="Load prompts")
23 |
24 | def load_handler(b):
25 | if not os.path.isfile(self.path_field.value): return
26 | self.__current_file = self.path_field.value
27 | self.canvas_holder.add_background(self.path_field.value, scale_factor = self.scale_factor)
28 | self.load_button.on_click(load_handler)
29 |
30 |         #TODO accordion doesn't work for some reason
31 | self.accordion = Accordion(children=[self.canvas_holder.render_element])
32 | self.accordion.set_title(0, "Canvas")
33 |
34 | def generate(self, pipe, init_image=None, mask=None, generator=None):
35 | """Generate images given Inpaint Pipeline, and settings set in Base UI and Refiner UI."""
36 | if generator is None: generator = self.__generator
37 | if self.__base_ui.seed_field.value >= 0:
38 | seed = self.__base_ui.seed_field.value
39 | else:
40 | seed = generator.seed()
41 |
42 | if init_image is None:
43 | init_image = Image.open(self.__current_file)
44 | size = init_image.size
45 | # TODO check if mask size matches, or always scale to img?
46 | if mask is None:
47 | mask = Image.fromarray(self.canvas_holder.canvas.get_image_data()[:,:,-1])
48 | if self.scale_factor != 1:
49 | mask = mask.resize((int(mask.size[0] / self.scale_factor), int(mask.size[1] / self.scale_factor)))
50 |
51 | g = torch.cuda.manual_seed(seed)
52 | self._metadata = self.__base_ui.get_metadata_string() + f"\nImg2Img Seed: {seed}, Noise Strength: {self.strength_field.value} "
53 |
54 | results = pipe(image=init_image,
55 | mask_image=mask,
56 | prompt=self.__base_ui.positive_prompt.value,
57 | negative_prompt=self.__base_ui.negative_prompt.value,
58 | num_inference_steps=self.__base_ui.steps_field.value,
59 | num_images_per_prompt = self.__base_ui.batch_field.value,
60 | guidance_scale=self.__base_ui.cfg_field.value,
61 | guidance_rescale=self.__base_ui.cfg_rescale,
62 | strength=self.strength_field.value,
63 | width=size[0], height=size[1],
64 | generator=g)
65 | return results
66 |
67 | @property
68 | def metadata(self):
69 | return self._metadata
70 |
71 | @property
72 | def render_element(self):
73 | return VBox([self.strength_field,
74 | HBox([self.path_field, self.load_button, self.load_prompts_button]),
75 | self.canvas_holder.render_element])
76 |
77 | def render(self):
78 | display(self.render_element)
79 |
80 | def _ipython_display_(self):
81 | self.render()
82 |
--------------------------------------------------------------------------------
/ColabUI/KandinskyUI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .BaseUI import BaseUI
3 |
4 | class KandinskyUI(BaseUI):
5 | def generate(self, pipe, sampler="p_sampler"):
6 |         if self.seed_field.value >= 0:
7 | torch.manual_seed(self.seed_field.value)
8 |
9 | self._metadata = self.get_metadata_string() + f"Seed: {self.seed_field.value} "
10 |
11 |         images = pipe.generate_text2img(self.positive_prompt.value,
12 |                                         negative_prior_prompt = self.negative_prompt.value,
13 |                                         num_steps=self.steps_field.value,
14 |                                         batch_size=self.batch_field.value,
15 |                                         guidance_scale=self.cfg_field.value,
16 |                                         h=self.height_field.value, w=self.width_field.value,
17 |                                         sampler=sampler)
18 | return images
19 |
--------------------------------------------------------------------------------
/ColabUI/LoraApplyer.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display
2 | from ipywidgets import Dropdown, Button, HBox, VBox, Layout, Text, FloatSlider, Output
3 | from ..utils.empty_output import EmptyOutput
4 |
5 | class LoraApplyer:
6 | def __init__(self, colab, out:Output = None, cache = None):
7 | self.colab = colab
8 | if out is None: out = EmptyOutput()
9 | self.out = out
10 | self.__cache = cache
11 | self.__applier_loras = dict()
12 | self.vbox = VBox()
13 | self.is_fused = False
14 |
15 | self.__setup_dropdown()
16 |
17 | self.add_button = Button(description="Add", layout=Layout(width='50px'))
18 | def add_lora(b):
19 | self.out.clear_output()
20 | with out: self.__add_lora(self.dropdown.value, 1)
21 | self.add_button.on_click(add_lora)
22 |
23 | self.fuse_button = Button(description="Fuse", layout=Layout(width='75px'), button_style='success')
24 | def on_fuse_btn(b):
25 | self.out.clear_output()
26 | with out: self.fuse_lora()
27 | self.fuse_button.on_click(on_fuse_btn)
28 | self.fuse_button.tooltip = "Merge loras into the model weights to speed-up inference"
29 |
30 | self.unfuse_button = Button(description="Unfuse", layout=Layout(width='75px'), button_style='danger')
31 | def on_unfuse_btn(b):
32 | self.out.clear_output()
33 | with out: self.unfuse_lora()
34 | self.unfuse_button.on_click(on_unfuse_btn)
35 | self.unfuse_button.tooltip = "Return model weights to the original state"
36 |
37 | self.__set_ui_fuse_state(False)
38 |
39 | def fuse_lora(self):
40 | if self.is_fused: return
41 | self.is_fused = True
42 | self.__set_ui_fuse_state(True)
43 | self.__apply_adapters()
44 | self.colab.pipe.fuse_lora()
45 |
46 | def unfuse_lora(self):
47 | if not self.is_fused: return
48 | self.is_fused = False
49 | self.__set_ui_fuse_state(False)
50 | self.colab.pipe.unfuse_lora()
51 |
52 | def __setup_dropdown(self):
53 | self.dropdown = Dropdown(
54 | options=[x for x in self.__cache],
55 | description="Lora:",
56 | )
57 | self.dropdown.description_tooltip = "Choose lora to load"
58 |
59 | def update_dropdown(self, new_adapter):
60 | self.dropdown.options = [x for x in self.__cache]
61 |
62 | def __add_lora(self, adapter: str, scale: float):
63 | if adapter in self.__applier_loras: return
64 | self.__applier_loras[adapter] = scale
65 | self.__rerender_items()
66 | self.__apply_adapters()
67 |
68 | def __rerender_items(self):
69 | w = []
70 | for adapter, scale in self.__applier_loras.items():
71 | slider = self.__add_lora_scale_slider(adapter, scale)
72 | remove_btn = self.__add_lora_remove_button(adapter)
73 | w.append(HBox([slider, remove_btn]))
74 |
75 | self.vbox.children = w
76 |
77 | def __add_lora_scale_slider(self, adapter, scale):
78 | slider = FloatSlider(min=0, max=1.5, value=scale, step=0.05, description=adapter)
79 | def value_changed(change):
80 | self.__applier_loras[adapter] = change.new
81 | self.__apply_adapters()
82 | slider.observe(value_changed, 'value')
83 | return slider
84 |
85 | def __add_lora_remove_button(self, adapter):
86 | remove_btn = Button(description="X", layout=Layout(width='40px'))
87 | def on_button(b):
88 | del self.__applier_loras[adapter]
89 | self.__rerender_items()
90 | self.__apply_adapters()
91 | remove_btn.on_click(on_button)
92 | return remove_btn
93 |
94 | def __apply_adapters(self):
95 | if self.is_fused: return
96 | self.colab.pipe.enable_lora() #TODO use separate button
97 | self.colab.pipe.set_adapters([x for x in self.__applier_loras.keys()],
98 | adapter_weights=[x for x in self.__applier_loras.values()])
99 |
100 | def __set_ui_fuse_state(self, value):
101 | self.add_button.disabled = value
102 | self.fuse_button.disabled = value
103 | self.unfuse_button.disabled = not value
104 | self.fuse_button.layout.width = "50px" if value else "75px"
105 | self.unfuse_button.layout.width = "50px" if not value else "75px"
106 | #self.fuse_button.layout.visibility = "hidden" if value else "visible"
107 | #self.unfuse_button.layout.visibility = "visible" if value else "hidden"
108 |
109 | @property
110 | def render_element(self):
111 | return VBox([HBox([self.dropdown, self.add_button, self.fuse_button, self.unfuse_button]), self.vbox])
112 |
113 | def render(self):
114 | display(self.render_element)
115 |
116 | def _ipython_display_(self):
117 | self.render()
118 |
--------------------------------------------------------------------------------
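LoraApplyer orchestrates the diffusers multi-adapter LoRA API. The underlying calls, shown directly for reference (pipe is a loaded diffusers pipeline; the adapter name and file are placeholders):

    pipe.load_lora_weights("Lora", weight_name="style.safetensors", adapter_name="style")
    pipe.set_adapters(["style"], adapter_weights=[0.8])  # live per-adapter weighting, pre-fuse
    pipe.fuse_lora()     # bake adapters into the weights to speed up inference
    pipe.unfuse_lora()   # restore the original weights (what "Unfuse" does)
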
/ColabUI/LoraChoice.py:
--------------------------------------------------------------------------------
1 | import markdown
2 | from ipywidgets import VBox, HTML, Output
3 | from .LoraDownloader import LoraDownloader
4 | from .LoraApplyer import LoraApplyer
5 | from IPython.display import display
6 | class LoraChoice:
7 | def __init__(self, colab, out:Output = None, download_dir:str = "Lora", lora_cache:dict = None):
8 | self.colab = colab
9 |
10 | self.label_download = HTML(markdown.markdown("## Load Lora"))
11 | self.lora_downloader = LoraDownloader(colab, out, output_dir=download_dir, cache=lora_cache)
12 | self.label_apply = HTML(markdown.markdown("## Apply Lora"))
13 | self.lora_ui = LoraApplyer(colab, out, cache=self.lora_downloader.cache)
14 |
15 | self.lora_downloader.on_load_event.clear_callbacks()
16 | self.lora_downloader.on_load_event.add_callback(self.lora_ui.update_dropdown)
17 |
18 | @property
19 | def render_element(self):
20 | return VBox([self.label_download, self.lora_downloader.render_element, self.label_apply, self.lora_ui.render_element])
21 |
22 | def render(self):
23 | display(self.render_element)
24 |
25 | def _ipython_display_(self):
26 | self.render()
--------------------------------------------------------------------------------
/ColabUI/LoraDownloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 | from IPython.display import display
4 | from ipywidgets import Dropdown, Button, HBox, Layout, Text, Output
5 | from ..utils.downloader import download_ckpt
6 | from ..utils.event import Event
7 | from ..utils.empty_output import EmptyOutput
8 |
9 | class LoraDownloader:
10 | __cache = dict()
11 |
12 | def __init__(self, colab, out:Output = None, output_dir:str = "Lora", cache = None):
13 | self.output_dir = output_dir
14 | self.colab = colab
15 | if out is None: out = EmptyOutput()
16 | self.out = out
17 |
18 | if cache is not None: self.load_cache(cache)
19 |
20 | self.url_text = Text(placeholder="Lora url:", layout=Layout(width="40%"))
21 | def url_changed(change):
22 | self.adapter_field.value = ""
23 | self.url_text.observe(url_changed, 'value')
24 | self.url_text.description_tooltip = "Url or filepath to Lora. Supports google drive links."
25 |
26 | self.adapter_field = Text(placeholder="Adapter name", layout=Layout(width='150px'))
27 | self.adapter_field.description_tooltip = "Custom name for Lora. Default is filename."
28 |
29 | self.load_button = Button(description="Load", layout=Layout(width='50px'))
30 | def load(b):
31 | self.out.clear_output()
32 | with self.out:
33 | self.load_lora(self.url_text.value, self.adapter_field.value)
34 | self.load_button.on_click(load)
35 |
36 | self.repair_button = Button(description="Repair", layout=Layout(width='75px'), button_style='info')
37 | def on_repair_btn(b):
38 | self.out.clear_output()
39 | with self.out: self.repair()
40 | self.repair_button.on_click(on_repair_btn)
41 | self.repair_button.tooltip = "Click to refresh list of loaded Loras in case of errors"
42 |
43 | self.on_load_event = Event()
44 |
45 | def load_lora(self, url, adapter_name=""):
46 | if os.path.isfile(url):
47 | filepath = url
48 | elif url.startswith("https://drive.google.com"):
49 | id = gdown.parse_url.parse_url(url)[0]
50 | if not self.output_dir.endswith("/"): self.output_dir += "/"
51 | filepath = gdown.download(f"https://drive.google.com/uc?id={id}", self.output_dir)
52 | elif url.startswith("https://civitai.com/api/download/models/") or url.endswith(".safetensors"):
53 | filepath = download_ckpt(url, self.output_dir)
54 | else:
55 | print("Error")
56 | return
57 |
58 | filename = os.path.basename(filepath)
59 | if adapter_name == "":
60 | adapter_name = os.path.splitext(filename)[0]
61 | self.adapter_field.value = adapter_name
62 | self.colab.pipe.load_lora_weights(self.output_dir, weight_name=filename, adapter_name=adapter_name)
63 | self.__cache[adapter_name] = filepath
64 |
65 | self.on_load_event.invoke(adapter_name)
66 |
67 | def load_cache(self, cache):
68 | for adapter_name, filepath in cache.items():
69 | try:
70 | filename = os.path.basename(filepath)
71 | self.colab.pipe.load_lora_weights(os.path.dirname(filepath), weight_name=filename, adapter_name=adapter_name)
72 | self.__cache[adapter_name] = filepath
73 | except: pass
74 |
75 | def repair(self):
76 | adapters_dict = self.colab.pipe.get_list_adapters()
77 | print(adapters_dict)
78 | adapters = {x for module in adapters_dict.values() for x in module}
79 | for a in adapters:
80 | self.__cache[a] = None
81 | self.on_load_event.invoke("") #no need to call for each adapter, TODO argument
82 |
83 | @property
84 | def cache(self):
85 | return self.__cache
86 |
87 | @property
88 | def render_element(self):
89 | return HBox([self.url_text, self.adapter_field, self.load_button, self.repair_button])
90 |
91 | def render(self):
92 | display(self.render_element)
93 |
94 | def _ipython_display_(self):
95 | self.render()
96 |
--------------------------------------------------------------------------------
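load_lora() dispatches on the source: an existing local file is used in place, Google Drive links go through gdown, and Civitai / raw .safetensors urls go through download_ckpt. Sketch (urls and names are placeholders; colab is a ColabWrapper with a loaded pipe):

    downloader = LoraDownloader(colab)
    downloader.load_lora("Lora/my_style.safetensors")                       # local file
    downloader.load_lora("https://civitai.com/api/download/models/12345", "style")
    downloader.load_lora("https://drive.google.com/uc?id=FILE_ID")          # Google Drive
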
/ColabUI/SamplerChoice.py:
--------------------------------------------------------------------------------
1 | from ipywidgets import Dropdown, Output
2 | from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
3 | from ..utils.empty_output import EmptyOutput
4 | from IPython.display import display
5 | class SamplerChoice:
6 | def __init__(self, colab, out:Output = None):
7 | self.colab = colab
8 |
9 | if out is None: out = EmptyOutput()
10 | self.out = out
11 |
12 | self.dropdown = Dropdown(
13 | options=["Euler A", "DPM++", "DPM++ Karras", "UniPC"],
14 | description='Sampler:',
15 | )
16 |
17 | def dropdown_eventhandler(change):
18 | self.out.clear_output()
19 | with self.out:
20 | self.choose_sampler(self.colab.pipe, change.new)
21 |
22 | self.dropdown.observe(dropdown_eventhandler, names='value')
23 |
24 | def choose_sampler(self, pipe, sampler_name: str):
25 | config = pipe.scheduler.config
26 | match sampler_name:
27 | case "Euler A": sampler = EulerAncestralDiscreteScheduler.from_config(config)
28 | case "DPM++": sampler = DPMSolverMultistepScheduler.from_config(config)
29 | case "DPM++ Karras":
30 | sampler = DPMSolverMultistepScheduler.from_config(config)
31 | sampler.use_karras_sigmas = True
32 | case "UniPC":
33 | sampler = UniPCMultistepScheduler.from_config(config)
34 | case _: raise NameError("Unknown sampler")
35 | pipe.scheduler = sampler
36 | print(f"Sampler '{sampler_name}' chosen")
37 |
38 | @property
39 | def render_element(self):
40 | return self.dropdown
41 |
42 | def render(self):
43 | display(self.render_element)
44 |
45 | def _ipython_display_(self):
46 | self.render()
47 |
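48 | # Usage sketch (assumes `colab` wraps a loaded diffusers pipe, as in the notebooks;
49 | # picking an option in the rendered dropdown applies the scheduler automatically):
50 | #
51 | #   choice = SamplerChoice(colab)
52 | #   choice.render()
53 | #   choice.choose_sampler(colab.pipe, "DPM++ Karras")   # or select via the dropdown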
--------------------------------------------------------------------------------
/ColabUI/SettingsTabs.py:
--------------------------------------------------------------------------------
1 | from ipywidgets import Button, Tab, VBox, Output, Accordion, IntSlider, Text, Checkbox, FloatSlider
2 | from .SamplerChoice import SamplerChoice
3 | from .VaeChoice import VaeChoice
4 | from .TextualInversionChoice import TextualInversionChoice
5 | from .LoraChoice import LoraChoice
6 | import os
7 |
8 | class SettingsTabs:
9 | #TODO config file, default import file
10 | def __init__(self, colab):
11 | self.colab = colab
12 | self.output = Output(layout={'border': '1px solid black'})
13 | self.clear_output_button = Button(description="Clear Log")
14 | def on_clear_clicked(b):
15 | self.output.clear_output()
16 | self.clear_output_button.on_click(on_clear_clicked)
17 |
18 | self.sampler_choice = SamplerChoice(self.colab, self.output)
19 | self.vae_choice = VaeChoice(self.colab, self.output)
20 | self.ti_choice = TextualInversionChoice(self.colab, self.output)
21 | self.lora_choice = LoraChoice(self.colab, self.output, lora_cache=colab.lora_cache)
22 | self.general_settings = self.__general_settings()
23 |
24 | self.tabs = Tab()
25 | self.tabs.children = [self.sampler_choice.render_element,
26 | self.vae_choice.render_element,
27 | self.ti_choice.render_element,
28 | self.lora_choice.render_element,
29 | self.general_settings]
30 |
31 | for i,x in enumerate(["Sampler", "VAE", "Textual Inversion", "Lora", "Other"]):
32 | self.tabs.set_title(i, x)
33 |
34 | self.accordion = Accordion(children=[VBox([self.tabs, self.clear_output_button, self.output])])
35 | self.accordion.set_title(0, "Settings")
36 | self.accordion.selected_index = None
37 |
38 | def __general_settings(self):
39 | self.show_preview_checkbox = Checkbox(value=True, description="Show output previews")
40 |
41 | clip_slider = IntSlider(value=self.colab.ui.clip_skip, min=0, max=4, description="Clip Skip")
42 | def clip_value_changed(change):
43 | self.output.clear_output()
44 | with self.output:
45 | self.colab.ui.clip_skip = change.new
46 | print(f"Clip Skip set to {change.new}")
47 | clip_slider.observe(clip_value_changed, "value")
48 |
49 | cfg_rescale_slider = FloatSlider(value=0, min=0, max=1, step = 0.05, description="CFG Rescale")
50 |         cfg_rescale_slider.description_tooltip = "Higher values correct contrast and overexposure"
51 | def rescale_value_changed(change):
52 | with self.output: self.colab.ui.cfg_rescale = change.new
53 | cfg_rescale_slider.observe(rescale_value_changed, "value")
54 |
55 | favourite_dir = Text(placeholder="Path to folder", description="Favourites:")
56 | favourite_dir.value = self.colab.favourite_dir
57 | def favourite_value_changed(change):
58 | self.output.clear_output()
59 | with self.output:
60 |                 if os.path.isdir(change.new):
61 |                     self.colab.favourite_dir = change.new
62 |                     favourite_dir.description = "Favourites:"
63 |                     print(f"Favourites folder changed to {change.new}")
64 | else:
65 | favourite_dir.description = self.highlight("Favourites:", color="red")
66 | favourite_dir.observe(favourite_value_changed, "value")
67 |
68 | return VBox([self.show_preview_checkbox, favourite_dir, clip_slider, cfg_rescale_slider])
69 |
70 | @staticmethod
71 | def highlight(str_to_highlight: str, color: str = "red") -> str:
72 |         return f"<font color='{color}'>{str_to_highlight}</font>"
73 |
74 | @property
75 | def display_previews(self):
76 | return self.show_preview_checkbox.value
77 |
78 | @property
79 | def render_element(self):
80 | return self.accordion
81 |
82 | def render(self):
83 | display(self.render_element)
84 |
85 | def _ipython_display_(self):
86 | self.render()
87 |
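88 | # Usage sketch (assumes `colab` is a ColabWrapper exposing `pipe`, `ui`,
89 | # `favourite_dir` and `lora_cache`, which this class reads):
90 | #
91 | #   settings = SettingsTabs(colab)
92 | #   settings.render()                    # collapsed accordion with all tabs
93 | #   if settings.display_previews: ...    # gate preview rendering on the checkbox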
--------------------------------------------------------------------------------
/ColabUI/TextualInversionChoice.py:
--------------------------------------------------------------------------------
1 | import os
2 | from ipywidgets import Text, Button, Layout, Combobox, VBox, Output
3 | from ..utils.empty_output import EmptyOutput
4 | from ..utils.markdown import SpoilerLabel
5 |
6 | class TextualInversionChoice:
7 | def __init__(self, colab, out:Output = None, default_path:str = "/content/embeddings/"):
8 | self.colab = colab
9 | if out is None: out = EmptyOutput()
10 | self.out = out
11 | self.tokens = list()
12 |
13 |         self.tooltip_label = SpoilerLabel("Tooltip", "Paste a path to a Textual Inversion .pt file, or a path to a folder containing them.")
14 | self.path = Text(description="TI:", placeholder='Path to file or folder...', layout=Layout(width="50%"))
15 | self.path.description_tooltip = "Path to a Textual Inversion file or root folder"
16 | if default_path is not None: self.path.value = default_path
17 |
18 | self.load_button = Button(description="Load", button_style='success')
19 | def on_load_clicked(b):
20 | self.out.clear_output()
21 | with self.out:
22 | self.load(self.colab.pipe)
23 | self.load_button.on_click(on_load_clicked)
24 |
25 | self.token_list = Combobox(placeholder="Start typing token", description="Tokens:", ensure_option=True)
26 |
27 | def load(self, pipe):
28 | if self.path.value == "":
29 | raise ValueError("Text field shouldn't be empty")
30 |
31 | if os.path.isfile(self.path.value):
32 | dir, filename = os.path.split(self.path.value)
33 | self.__load_textual_inversion(self.colab.pipe, dir, filename)
34 | elif os.path.isdir(self.path.value):
35 | self.__load_textual_inversions_from_folder(self.colab.pipe, self.path.value)
36 |
37 | #TODO load from file: https://huggingface.co/docs/diffusers/api/loaders/textual_inversion
38 | def __load_textual_inversion(self, pipe, path: str, filename: str):
39 | self.colab.pipe.load_textual_inversion(path, weight_name=filename)
40 |
41 | token = f"<{os.path.splitext(filename)[0]}>"
42 | self.__add_token(token)
43 | print(token)
44 |
45 | def __load_textual_inversions_from_folder(self, pipe, root_folder: str):
46 | n = 0
47 | for path, subdirs, files in os.walk(root_folder):
48 | for name in files:
49 | try:
50 | if os.path.splitext(name)[1] != ".pt": continue
51 | self.colab.pipe.load_textual_inversion(path, weight_name=name)
52 |
53 | token = f"<{os.path.splitext(name)[0]}>"
54 | self.__add_token(token)
55 | print(f"{path}:\t{token}")
56 | n += 1
57 |                 except Exception: pass #TODO save exception?
58 | print(f"{n} items added")
59 |
60 | def __add_token(self, token:str):
61 | self.tokens.append(token)
62 | self.token_list.options = self.tokens
63 |
64 | @property
65 | def render_element(self):
66 | return VBox([self.tooltip_label, self.path, self.load_button, self.token_list])
67 |
68 | def render(self):
69 | display(self.render_element)
70 |
71 | def _ipython_display_(self):
72 | self.render()
73 |
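74 | # Usage sketch (assumes `colab.pipe` is a loaded pipeline and embeddings were
75 | # cloned to the default /content/embeddings/ path, as the notebooks' install cell does):
76 | #
77 | #   ti = TextualInversionChoice(colab)
78 | #   ti.render()
79 | #   ti.load(colab.pipe)   # loads every .pt it finds and lists their tokens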
--------------------------------------------------------------------------------
/ColabUI/VaeChoice.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from ipywidgets import Text, Layout, Button, HBox, VBox, Output
4 | from diffusers import AutoencoderKL
5 | from ..utils.downloader import download_ckpt
6 | from ..utils.empty_output import EmptyOutput
7 | from ..utils.markdown import SpoilerLabel
8 |
9 | class VaeChoice:
10 | def __init__(self, colab, out:Output = None, default_id:str = "waifu-diffusion/wd-1-5-beta2"):
11 | self.colab = colab
12 | if out is None: out = EmptyOutput()
13 | self.out = out
14 |
15 |         self.tooltip_label = SpoilerLabel("Loading single file", "If the VAE is a single file, just leave the 'Subfolder' field empty.")
16 | self.id_text = Text(description="VAE:", placeholder='Id or path...', layout=Layout(width="35%"))
17 | self.id_text.description_tooltip = "Huggingface model id, url, or a path to a root folder/file"
18 | if default_id is not None: self.id_text.value = default_id
19 | self.subfolder_text = Text(description="Subfolder:", layout=Layout(width="25%"))
20 | self.subfolder_text.description_tooltip = "Subfolder name, if using diffusers-style model"
21 | if default_id is not None: self.subfolder_text.value = "vae"
22 |
23 | self.button = Button(description="Load", button_style='success')
24 |
25 | def on_button_clicked(b):
26 | self.out.clear_output()
27 | with self.out:
28 | self.load_vae(self.colab.pipe)
29 |
30 | self.button.on_click(on_button_clicked)
31 |
32 | def load_vae(self, pipe):
33 | if self.id_text.value == "":
34 | raise ValueError("VAE text field shouldn't be empty")
35 |
36 | url = self.id_text.value
37 | if url.endswith(".safetensors") or url.startswith("https://civitai.com/api/download/models/"):
38 | if not os.path.isfile(url): url = download_ckpt(url) #TODO download dir
39 | vae = AutoencoderKL.from_single_file(url, torch_dtype=torch.float16)
40 | else:
41 | if self.subfolder_text.value == "":
42 | raise ValueError("Subfolder text field shouldn't be empty when loading diffusers-style model")
43 | vae = AutoencoderKL.from_pretrained(url, subfolder=self.subfolder_text.value, torch_dtype=torch.float16)
44 |
45 | pipe.vae = vae.to("cuda")
46 | print(f"{self.id_text.value} VAE loaded")
47 |
48 | @property
49 | def render_element(self):
50 | return VBox([self.tooltip_label, HBox([self.id_text, self.subfolder_text]), self.button])
51 |
52 | def render(self):
53 | display(self.render_element)
54 |
55 | def _ipython_display_(self):
56 | self.render()
57 |
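58 | # Usage sketch (assumes `colab.pipe` is a loaded pipeline on cuda; the defaults
59 | # point at the waifu-diffusion/wd-1-5-beta2 repo with its "vae" subfolder):
60 | #
61 | #   vae_ui = VaeChoice(colab)
62 | #   vae_ui.render()
63 | #   vae_ui.load_vae(colab.pipe)   # or click the "Load" button in the rendered UI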
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Andrey Sugonyaev
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Colab UI for Stable Diffusion
2 |
3 | [](https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_diffusers_colab_ui.ipynb)
4 |
5 | Colab UI for running Stable Diffusion with 🤗 [diffusers](https://github.com/huggingface/diffusers) library.
6 |
7 | This repository aims to create a GUI using native Colab and IPython widgets.
8 | This eliminates the need for gradio WebUI, which [seems to be prohibited](https://github.com/googlecolab/colabtools/issues/3591) now on Google Colab.
9 |
10 | 
11 |
12 | ### Features:
13 | - [X] UI based on IPython widgets
14 | - [X] Stable diffusion 1.x, 2.x, XL
15 | - [X] Load models from Huggingface, as well as models in Automatic1111 (*ckpt/safetensors*) format
16 | - [X] Change VAE and sampler
17 | - [X] Load textual inversions
18 | - [x] Load LoRAs
19 | - [x] Img2Img
20 | - [x] Inpainting
21 |
22 | ### SDXL
23 | SDXL is supported to an extent. You can load huggingface models, but loading models in A1111 format may be impossible because of RAM limitations.
24 | Textual inversions and LoRA loading also appear to be unsupported for now.
25 |
26 | ## Kandinsky
27 | [](https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_kandinsky_colab_ui.ipynb)
28 |
29 | Colab UI for Kandinsky txt2img and image mixing.
30 |
31 | Original repo: https://github.com/ai-forever/Kandinsky-2
32 |
--------------------------------------------------------------------------------
/artists.txt:
--------------------------------------------------------------------------------
1 | Alphonse Mucha
2 | Coles Phillips
3 | Frank Frazetta
4 | Alex Russell Flint
5 | Jeffrey Catherine Jones
6 | Ilya Kuvshinov
7 | Josephine Wall
8 | anne bachelier
9 | Russ Mills
10 | Ivan Bilibin
11 | Helen Frankenthaler
12 | Sergio Aragones
13 | Don Bluth
14 | Carl Barks
15 | Gemma Correll
16 | Tex Avery
17 | Matt Bors
18 | Quentin Blake
19 | Richard Scarry
20 | Alex Toth
21 | Steve Dillon
22 | Leone Frollo
23 | Elsa Beskow
24 | Howard Chaykin
25 | Rafael Albuquerque
26 | Hirohiko Araki
27 | Hiromu Arakawa
28 | Arthur Rackham
29 | Henri Fantin-Latour
30 | Eddie Campbell
31 | Abigail Larson
32 | Kaethe Butcher
33 | Aubrey Beardsley
34 | Kate Greenaway
35 | John Blanche
36 | Simon Bisley
37 |
--------------------------------------------------------------------------------
/docs/ui-example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oneir0mancer/stable-diffusion-diffusers-colab-ui/fd69311417ddbf3f051d893af445d10e14b6542b/docs/ui-example.jpg
--------------------------------------------------------------------------------
/model_index.json:
--------------------------------------------------------------------------------
1 | {
2 | "Stable Diffusion v1.5" : { "id" : "runwayml/stable-diffusion-v1-5" },
3 | "Stable Diffusion XL" : { "id" : "stabilityai/stable-diffusion-xl-base-1.0" },
4 | "Animagine XL" : { "id" : "cagliostrolab/animagine-xl-3.0" },
5 | "Pony Diffusion XL" : { "id" : "stablediffusionapi/pony-diffusion-v6-xl" },
6 | "Counterfeit-V2.5" : { "id" : "gsdf/Counterfeit-V2.5" },
7 | "Analog Diffusion" : { "id" : "wavymulder/Analog-Diffusion", "trigger": "analog style" },
8 | "Dreamlike Diffusion" : { "id" : "dreamlike-art/dreamlike-diffusion-1.0", "trigger": "dreamlikeart" },
9 | "Dreamlike Photoreal 2.0" : { "id" : "dreamlike-art/dreamlike-photoreal-2.0" },
10 | "Openjourney v4" : { "id" : "prompthero/openjourney-v4" },
11 | "Pastel Mix" : { "id" : "andite/pastel-mix" },
12 | "Pony Diffusion v4" : { "id" : "AstraliteHeart/pony-diffusion-v4" },
13 | "ACertainThing" : { "id" : "JosephusCheung/ACertainThing" },
14 | "BPModel" : { "id" : "Crosstyan/BPModel" },
15 | "SomethingV2.2" : { "id" : "NoCrypt/SomethingV2_2" },
16 | "Anything v3.0" : { "id" : "Linaqruf/anything-v3.0" },
17 | "Anything v4.0 mix" : { "id" : "andite/anything-v4.0" },
18 |     "Waifu Diffusion v1.4" : { "id" : "hakurei/waifu-diffusion" },
19 |     "Waifu Diffusion v1.5" : { "id" : "waifu-diffusion/wd-1-5-beta3" }
20 | }
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | torchvision==0.20.1
3 | xformers
4 | accelerate
5 | diffusers
6 | omegaconf
7 | peft
8 | pytorch-lightning
9 | safetensors
10 | transformers
11 |
--------------------------------------------------------------------------------
/scripts/convert_original_stable_diffusion_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Conversion script for the LDM checkpoints. """
16 |
17 | import argparse
18 |
19 | import torch
20 |
21 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser()
26 |
27 | parser.add_argument(
28 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
29 | )
30 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
31 | parser.add_argument(
32 | "--original_config_file",
33 | default=None,
34 | type=str,
35 | help="The YAML config file corresponding to the original architecture.",
36 | )
37 | parser.add_argument(
38 | "--num_in_channels",
39 | default=None,
40 | type=int,
41 | help="The number of input channels. If `None` number of input channels will be automatically inferred.",
42 | )
43 | parser.add_argument(
44 | "--scheduler_type",
45 | default="pndm",
46 | type=str,
47 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
48 | )
49 | parser.add_argument(
50 | "--pipeline_type",
51 | default=None,
52 | type=str,
53 | help=(
54 | "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
55 | ". If `None` pipeline will be automatically inferred."
56 | ),
57 | )
58 | parser.add_argument(
59 | "--image_size",
60 | default=None,
61 | type=int,
62 | help=(
63 |             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
64 | " Base. Use 768 for Stable Diffusion v2."
65 | ),
66 | )
67 | parser.add_argument(
68 | "--prediction_type",
69 | default=None,
70 | type=str,
71 | help=(
72 | "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
73 | " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
74 | ),
75 | )
76 | parser.add_argument(
77 | "--extract_ema",
78 | action="store_true",
79 | help=(
80 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
81 | " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
82 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
83 | ),
84 | )
85 | parser.add_argument(
86 | "--upcast_attention",
87 | action="store_true",
88 | help=(
89 | "Whether the attention computation should always be upcasted. This is necessary when running stable"
90 | " diffusion 2.1."
91 | ),
92 | )
93 | parser.add_argument(
94 | "--from_safetensors",
95 | action="store_true",
96 | help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
97 | )
98 | parser.add_argument(
99 | "--to_safetensors",
100 | action="store_true",
101 | help="Whether to store pipeline in safetensors format or not.",
102 | )
103 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
104 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
105 | parser.add_argument(
106 | "--stable_unclip",
107 | type=str,
108 | default=None,
109 | required=False,
110 | help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
111 | )
112 | parser.add_argument(
113 | "--stable_unclip_prior",
114 | type=str,
115 | default=None,
116 | required=False,
117 | help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
118 | )
119 | parser.add_argument(
120 | "--clip_stats_path",
121 | type=str,
122 | help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
123 | required=False,
124 | )
125 | parser.add_argument(
126 | "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
127 | )
128 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
129 | args = parser.parse_args()
130 |
131 | pipe = download_from_original_stable_diffusion_ckpt(
132 | checkpoint_path=args.checkpoint_path,
133 | original_config_file=args.original_config_file,
134 | image_size=args.image_size,
135 | prediction_type=args.prediction_type,
136 | model_type=args.pipeline_type,
137 | extract_ema=args.extract_ema,
138 | scheduler_type=args.scheduler_type,
139 | num_in_channels=args.num_in_channels,
140 | upcast_attention=args.upcast_attention,
141 | from_safetensors=args.from_safetensors,
142 | device=args.device,
143 | stable_unclip=args.stable_unclip,
144 | stable_unclip_prior=args.stable_unclip_prior,
145 | clip_stats_path=args.clip_stats_path,
146 | controlnet=args.controlnet,
147 | )
148 |
149 | if args.half:
150 | pipe.to(torch_dtype=torch.float16)
151 |
152 | if args.controlnet:
153 | # only save the controlnet model
154 | pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
155 | else:
156 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
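157 |
158 | # Example invocation (a sketch; file names here are placeholders):
159 | #   python scripts/convert_original_stable_diffusion_to_diffusers.py \
160 | #       --checkpoint_path model.safetensors --from_safetensors \
161 | #       --original_config_file v1-inference.yaml --dump_path models/ModelName --half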
--------------------------------------------------------------------------------
/scripts/convert_vae_pt_to_diffusers.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 |
4 | import requests
5 | import torch
6 | from omegaconf import OmegaConf
7 | from safetensors.torch import load_file
8 | from diffusers import AutoencoderKL
9 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
10 | assign_to_checkpoint,
11 | conv_attn_to_linear,
12 | create_vae_diffusers_config,
13 | renew_vae_attention_paths,
14 | renew_vae_resnet_paths,
15 | )
16 |
17 |
18 | def custom_convert_ldm_vae_checkpoint(checkpoint, config):
19 | vae_state_dict = checkpoint
20 |
21 | new_checkpoint = {}
22 |
23 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
24 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
25 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
26 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
27 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
28 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
29 |
30 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
31 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
32 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
33 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
34 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
35 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
36 |
37 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
38 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
39 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
40 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
41 |
42 | # Retrieves the keys for the encoder down blocks only
43 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
44 | down_blocks = {
45 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
46 | }
47 |
48 | # Retrieves the keys for the decoder up blocks only
49 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
50 | up_blocks = {
51 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
52 | }
53 |
54 | for i in range(num_down_blocks):
55 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
56 |
57 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
58 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
59 | f"encoder.down.{i}.downsample.conv.weight"
60 | )
61 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
62 | f"encoder.down.{i}.downsample.conv.bias"
63 | )
64 |
65 | paths = renew_vae_resnet_paths(resnets)
66 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
67 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
68 |
69 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
70 | num_mid_res_blocks = 2
71 | for i in range(1, num_mid_res_blocks + 1):
72 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
73 |
74 | paths = renew_vae_resnet_paths(resnets)
75 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
76 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
77 |
78 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
79 | paths = renew_vae_attention_paths(mid_attentions)
80 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
81 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
82 | conv_attn_to_linear(new_checkpoint)
83 |
84 | for i in range(num_up_blocks):
85 | block_id = num_up_blocks - 1 - i
86 | resnets = [
87 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
88 | ]
89 |
90 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
91 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
92 | f"decoder.up.{block_id}.upsample.conv.weight"
93 | ]
94 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
95 | f"decoder.up.{block_id}.upsample.conv.bias"
96 | ]
97 |
98 | paths = renew_vae_resnet_paths(resnets)
99 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
100 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
101 |
102 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
103 | num_mid_res_blocks = 2
104 | for i in range(1, num_mid_res_blocks + 1):
105 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
106 |
107 | paths = renew_vae_resnet_paths(resnets)
108 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
109 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
110 |
111 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
112 | paths = renew_vae_attention_paths(mid_attentions)
113 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
114 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
115 | conv_attn_to_linear(new_checkpoint)
116 | return new_checkpoint
117 |
118 |
119 | def vae_pt_to_vae_diffuser(
120 |     checkpoint_path: str,
121 |     output_path: str,
122 |     from_safetensors: bool = False):
123 | # Only support V1
124 | r = requests.get(
125 |         "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
126 | )
127 | io_obj = io.BytesIO(r.content)
128 |
129 | original_config = OmegaConf.load(io_obj)
130 | image_size = 512
131 | device = "cuda" if torch.cuda.is_available() else "cpu"
132 |     checkpoint = load_file(checkpoint_path, device=device) if from_safetensors else torch.load(checkpoint_path, map_location=device)
133 |
134 | # Convert the VAE model.
135 | vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
136 |     converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint.get("state_dict", checkpoint), vae_config)  # handle both nested ("state_dict") and flat checkpoints
137 |
138 | vae = AutoencoderKL(**vae_config)
139 | vae.load_state_dict(converted_vae_checkpoint)
140 | vae.save_pretrained(output_path)
141 |
142 |
143 | if __name__ == "__main__":
144 | parser = argparse.ArgumentParser()
145 |
146 |     parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
147 |     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output diffusers VAE folder.")
148 |     parser.add_argument("--from_safetensors", action="store_true", help="Load the checkpoint with safetensors (for .safetensors files).")
149 | args = parser.parse_args()
150 |
151 |     vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path, from_safetensors=args.from_safetensors)
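152 |
153 | # Example invocation (a sketch; paths mirror the notebooks' conversion cells):
154 | #   python scripts/convert_vae_pt_to_diffusers.py --vae_pt_path /content/models_sd/kl-f8-anime2.ckpt --dump_path vae/VaeName/
155 | # Add --from_safetensors when the checkpoint is a .safetensors file.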
--------------------------------------------------------------------------------
/sd_diffusers_colab_ui.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 |     "<a href=\"https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_diffusers_colab_ui.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "cellView": "form",
18 | "id": "x-z4MGy9rz9I"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "#@title #Install dependencies\n",
23 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi\n",
24 | "\n",
25 | "!pip install -r StableDiffusionUi/requirements.txt\n",
26 | "!wget https://github.com/huggingface/diffusers/raw/refs/heads/main/examples/community/lpw_stable_diffusion_xl.py -O StableDiffusionUi/lpw_stable_diffusion_xl.py\n",
27 | "!wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n",
28 | "\n",
29 | "!mkdir -p outputs/{txt2img,img2img}\n",
30 | "!git clone https://huggingface.co/embed/negative /content/embeddings/negative\n",
31 | "!git clone https://huggingface.co/embed/lora /content/Lora/positive\n"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "id": "QRvNp3ElcaA3"
38 | },
39 | "source": [
40 | "# Creating model pipeline"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {
46 | "id": "dVlvLpShS9eo"
47 | },
48 | "source": [
49 | "\n",
50 |     "<details>\n",
51 |     "<summary>Description</summary> TLDR: use dropdown or just paste stuff\n",
52 | "\n",
53 | "---\n",
54 | "\n",
55 | "You can just choose a model from a dropdown of popular models, which will also give you a link to a model page on [huggingface](https://huggingface.co) and any trigger words you need to use in the prompt.\n",
56 | "\n",
57 | "Alternatively, you can paste a **model_id**, **url**, or **local path** for any model from huggingface:\n",
58 | "\n",
59 | "* **model_id** looks like this: `gsdf/Counterfeit-V2.5`.\n",
60 |     "<br> You can use it for any *proper* diffusers model in huggingface, i.e. ones that have a lot of folders in their **Files**, and a file called **model_index.json**\n",
61 |     "<br> `🔲 A1111 format` toggle should be **off**.\n",
62 | "\n",
63 | "* **url** from HuggingFace may look like this:\n",
64 | "```\n",
65 | "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors\n",
66 | "```\n",
67 | "Use it when the model is in Automatic1111 format (as in, just a single file).\n",
68 |     "<br> `✅ A1111 format` toggle should be **on**.\n",
69 | "\n",
70 | "* **url** from CivitAI should look like this:\n",
71 | "```\n",
72 | "https://civitai.com/api/download/models/XXXXXX\n",
73 | "```\n",
74 | "(the link you get from \"Download\" button)\n",
75 |     "<br> `✅ A1111 format` toggle should be **on**.\n",
76 | "\n",
77 | "* **local path** is just a path to a model folder (if diffusers format) or model file (if A1111 format).\n",
78 |     "<br> Set toggle accordingly.\n",
79 | "\n",
80 |     "</details>"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {
86 | "id": "dviyU3E6jK01"
87 | },
88 | "source": [
89 | "\n",
90 |     "<details><summary>Where to find urls</summary>\n",
91 | "\n",
92 |     "A good place to find model urls is [camenduru colabs repo](https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n",
93 | "\n",
94 |     "</details>"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {
101 | "cellView": "form",
102 | "id": "k_TR5dUwznN9"
103 | },
104 | "outputs": [],
105 | "source": [
106 | "#@title Render model choice UI\n",
107 | "\n",
108 | "from StableDiffusionUi.ColabUI.ColabWrapper import ColabWrapper\n",
109 | "\n",
110 | "colab = ColabWrapper(\"outputs/txt2img\")\n",
111 | "colab.render_model_index(\"/content/StableDiffusionUi/model_index.json\")"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "cellView": "form",
119 | "id": "IBdm3HvI02Yo"
120 | },
121 | "outputs": [],
122 | "source": [
123 | "#@title Load chosen SD model\n",
124 | "\n",
125 | "#TODO we have to use this until LPW properly supports clip skip\n",
126 | "from StableDiffusionUi.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline\n",
127 | "colab.load_model(StableDiffusionLongPromptWeightingPipeline, custom_pipeline=None)"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "source": [
133 | "#@title (*Optional*) Run this instead to load SDXL model\n",
134 | "from diffusers import StableDiffusionXLPipeline\n",
135 | "from StableDiffusionUi.lpw_stable_diffusion_xl import SDXLLongPromptWeightingPipeline\n",
136 | "\n",
137 | "colab.load_model(SDXLLongPromptWeightingPipeline, custom_pipeline=\"lpw_stable_diffusion_xl\")"
138 | ],
139 | "metadata": {
140 | "cellView": "form",
141 | "id": "AG4ChqboGDO_"
142 | },
143 | "execution_count": null,
144 | "outputs": []
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "metadata": {
149 | "id": "pAJzkJWpeOxk"
150 | },
151 | "source": [
152 | "\n",
153 |     "<details><summary>Prompting detail</summary>\n",
154 | "\n",
155 |     "I use [this](https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion) custom pipeline, which has no token length limit and lets you use weights in the prompt.\n",
156 | "\n",
157 |     "</details>"
158 | ]
159 | },
160 | {
161 | "cell_type": "markdown",
162 | "metadata": {
163 | "id": "hLmzZmRIZMpC"
164 | },
165 | "source": [
166 | "# Generating images"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {
173 | "cellView": "form",
174 | "id": "oFruEVcl4ODt"
175 | },
176 | "outputs": [],
177 | "source": [
178 | "#@title Render UI\n",
179 | "#@markdown ---\n",
180 | "from StableDiffusionUi.ColabUI.DiffusionPipelineUI import DiffusionPipelineUI\n",
181 | "\n",
182 | "ui = DiffusionPipelineUI()\n",
183 | "colab.render_generation_ui(ui)"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "metadata": {
190 | "id": "Y43YLjY7F7NI",
191 | "cellView": "form"
192 | },
193 | "outputs": [],
194 | "source": [
195 | "#@title Run this cell to generate images\n",
196 | "colab.generate()"
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "metadata": {
202 | "id": "gthYZKknKsw7"
203 | },
204 | "source": [
205 | "#Img2Img"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "source": [
211 | "#@title Img2img refiner UI\n",
212 | "# We have to load custom pipeline ourselves because you can't pass it when reusing pipe components\n",
213 | "from diffusers.utils import get_class_from_dynamic_module\n",
214 | "from diffusers import AutoPipelineForImage2Image\n",
215 | "\n",
216 | "if colab.custom_pipeline is not None:\n",
217 | " Interface = get_class_from_dynamic_module(colab.custom_pipeline, module_file=f\"{colab.custom_pipeline}.py\")\n",
218 | "else:\n",
219 | " Interface = StableDiffusionLongPromptWeightingPipeline #AutoPipelineForImage2Image\n",
220 | "\n",
221 | "colab.render_refiner_ui(Interface)"
222 | ],
223 | "metadata": {
224 | "cellView": "form",
225 | "id": "d9oniqPLdZ4x"
226 | },
227 | "execution_count": null,
228 | "outputs": []
229 | },
230 | {
231 | "cell_type": "code",
232 | "source": [
233 | "#@title Run img2img\n",
234 | "colab.refiner_generate(output_dir=\"outputs/img2img\")"
235 | ],
236 | "metadata": {
237 | "cellView": "form",
238 | "id": "zyk-JX0OeKvB"
239 | },
240 | "execution_count": null,
241 | "outputs": []
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "source": [
246 | "#Inpainting"
247 | ],
248 | "metadata": {
249 | "id": "o7aDqCk_EkwD"
250 | }
251 | },
252 | {
253 | "cell_type": "code",
254 | "source": [
255 | "#@title Setup\n",
256 |     "#@markdown If the UI doesn't show, you may need to restart the runtime. This feature is experimental and may not work properly.\n",
257 | "from google.colab import output\n",
258 | "output.enable_custom_widget_manager()\n",
259 | "\n",
260 | "!pip install ipycanvas"
261 | ],
262 | "metadata": {
263 | "cellView": "form",
264 | "id": "q0uVdOA7EnRj"
265 | },
266 | "execution_count": null,
267 | "outputs": []
268 | },
269 | {
270 | "cell_type": "code",
271 | "source": [
272 | "#@title Inpainting UI\n",
273 | "#@markdown Click \"Save\" button before running inpaint.\n",
274 | "\n",
275 | "# We have to load custom pipeline ourselves because you can't pass it when reusing pipe components\n",
276 | "from diffusers.utils import get_class_from_dynamic_module\n",
277 | "from diffusers import AutoPipelineForInpainting\n",
278 | "from StableDiffusionUi.ColabUI.InpaintRefinerUI import InpaintRefinerUI\n",
279 | "\n",
280 | "if colab.custom_pipeline is not None:\n",
281 | " Interface = get_class_from_dynamic_module(colab.custom_pipeline, module_file=f\"{colab.custom_pipeline}.py\")\n",
282 | "else:\n",
283 | " Interface = StableDiffusionLongPromptWeightingPipeline #TODO actually need to check if we use XL or not\n",
284 | "\n",
285 | "inpaint_ui = InpaintRefinerUI(colab.ui)\n",
286 | "colab.render_inpaint_ui(Interface, inpaint_ui)"
287 | ],
288 | "metadata": {
289 | "cellView": "form",
290 | "id": "SGb1bthqEwvi"
291 | },
292 | "execution_count": null,
293 | "outputs": []
294 | },
295 | {
296 | "cell_type": "code",
297 | "source": [
298 | "#@title Run inpaint\n",
299 | "colab.inpaint_generate(output_dir=\"outputs/img2img\")"
300 | ],
301 | "metadata": {
302 | "cellView": "form",
303 | "id": "zvi2DqRhFZry"
304 | },
305 | "execution_count": null,
306 | "outputs": []
307 | },
308 | {
309 | "cell_type": "markdown",
310 | "source": [
311 | "#Conversion"
312 | ],
313 | "metadata": {
314 | "id": "mHUsrD28erB4"
315 | }
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {
320 | "id": "kRUTJXoBc4tG"
321 | },
322 | "source": [
323 | "## Converting models from ckpt\n",
324 | "You shouldn't really need this, because diffusers now support loading from A1111 format.\n",
325 | "\n",
326 | "But if you want, you can download a model and convert it to diffusers format."
327 | ]
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {
332 | "id": "av9urfOqi7Ak"
333 | },
334 | "source": [
335 |     "NOTE: A good place to find model urls is [camenduru colabs repo](https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "metadata": {
342 | "cellView": "form",
343 | "id": "w1JZPBKId6Zv"
344 | },
345 | "outputs": [],
346 | "source": [
347 | "#@markdown ##Download ckpt file\n",
348 | "#@markdown Download ckpt/safetensors file from huggingface url.\n",
349 | "\n",
350 | "#@markdown ---\n",
351 | "import os\n",
352 | "\n",
353 | "url = \"https://huggingface.co/mekabu/MagicalMix_v2/resolve/main/MagicalMix_v2.safetensors\" #@param {type:\"string\"}\n",
354 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n",
355 | "model_name = url.split('/')[-1]\n",
356 | "\n",
357 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n",
358 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n",
359 | "os.system(bashCommand)"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "cellView": "form",
367 | "id": "dOn6tR_4eJ8k"
368 | },
369 | "outputs": [],
370 | "source": [
371 | "#@markdown ##(Optional) Or just upload file\n",
372 | "from google.colab import files\n",
373 | "f = files.upload()\n",
374 | "\n",
375 | "ckpt_dump_folder = \"/content/\"\n",
376 | "for key in f.keys():\n",
377 | " model_name = key\n",
378 | " break"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {
385 | "cellView": "form",
386 | "id": "N1GOOJrej6tW"
387 | },
388 | "outputs": [],
389 | "source": [
390 | "#@markdown ##Run conversion script\n",
391 | "#@markdown Paste a path to where you want this script to cache a model into `dump_path`.\n",
392 | "\n",
393 | "#@markdown Keep `override_path` empty unless you uploaded your file to some custom directory.\n",
394 | "\n",
395 | "import os\n",
396 | "import torch\n",
397 | "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n",
398 | "\n",
399 | "from_safetensors = True #@param {type:\"boolean\"}\n",
400 | "override_path = \"\" #@param {type:\"string\"}\n",
401 | "if override_path != \"\":\n",
402 | " checkpoint_path = override_path\n",
403 | "else:\n",
404 | " checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n",
405 | "\n",
406 | "pipe = download_from_original_stable_diffusion_ckpt(\n",
407 | " checkpoint_path=checkpoint_path,\n",
408 | " original_config_file = \"/content/v1-inference.yaml\",\n",
409 | " from_safetensors=from_safetensors,\n",
410 | " )\n",
411 | "\n",
412 | "dump_path=\"models/ModelName/\" #@param {type:\"string\"}\n",
413 | "pipe.save_pretrained(dump_path, safe_serialization=from_safetensors)\n",
414 | "\n",
415 | "pipe = pipe.to(\"cuda\")\n",
416 | "pipe.safety_checker = None\n",
417 | "pipe.to(torch_dtype=torch.float16)"
418 | ]
419 | },
420 | {
421 | "cell_type": "markdown",
422 | "metadata": {
423 | "id": "ii9UDlWQnIhg"
424 | },
425 | "source": [
426 | "## Converting VAE"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {
433 | "cellView": "form",
434 | "id": "916cyCYPnNVt"
435 | },
436 | "outputs": [],
437 | "source": [
438 | "#@markdown ##Download vae file\n",
439 | "#@markdown Basically, the same thing as model files\n",
440 | "\n",
441 | "#@markdown ---\n",
442 | "import os\n",
443 | "\n",
444 | "url = \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\" #@param {type:\"string\"}\n",
445 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n",
446 | "model_name = url.split('/')[-1]\n",
447 | "\n",
448 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n",
449 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n",
450 | "os.system(bashCommand)"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {
457 | "cellView": "form",
458 | "id": "dWC8UpJ3m8Fa"
459 | },
460 | "outputs": [],
461 | "source": [
462 | "#@markdown Paste a path to where you want this script to cache a vae into `dump_path`.\n",
463 | "from_safetensors = False #@param {type:\"boolean\"}\n",
464 | "dump_path=\"vae/VaeName/\" #@param {type:\"string\"}\n",
465 | "\n",
466 | "checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n",
467 | "bashCommand = f\"python /content/StableDiffusionUi/scripts/convert_vae_pt_to_diffusers.py --vae_pt_path {checkpoint_path} --dump_path {dump_path}\"\n",
468 | "if from_safetensors:\n",
469 | " bashCommand += \" --from_safetensors\"\n",
470 | "os.system(bashCommand)"
471 | ]
472 | },
473 | {
474 | "cell_type": "markdown",
475 | "source": [
476 | "#Flush model\n"
477 | ],
478 | "metadata": {
479 | "id": "33zsLV1d0CKE"
480 | }
481 | },
482 | {
483 | "cell_type": "code",
484 | "source": [
485 | "import gc\n",
486 | "import torch\n",
487 | "\n",
488 | "del colab.pipe\n",
489 | "torch.cuda.empty_cache()\n",
490 | "gc.collect()"
491 | ],
492 | "metadata": {
493 | "id": "aWqoBQP60OyY"
494 | },
495 | "execution_count": null,
496 | "outputs": []
497 | }
498 | ],
499 | "metadata": {
500 | "accelerator": "GPU",
501 | "colab": {
502 | "collapsed_sections": [
503 | "gthYZKknKsw7",
504 | "o7aDqCk_EkwD",
505 | "mHUsrD28erB4",
506 | "kRUTJXoBc4tG",
507 | "ii9UDlWQnIhg",
508 | "33zsLV1d0CKE"
509 | ],
510 | "provenance": [],
511 | "include_colab_link": true
512 | },
513 | "gpuClass": "standard",
514 | "kernelspec": {
515 | "display_name": "Python 3",
516 | "name": "python3"
517 | },
518 | "language_info": {
519 | "name": "python"
520 | }
521 | },
522 | "nbformat": 4,
523 | "nbformat_minor": 0
524 | }
--------------------------------------------------------------------------------
/sd_diffusers_img2img_ui.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 |     "<a href=\"https://colab.research.google.com/github/oneir0mancer/stable-diffusion-diffusers-colab-ui/blob/main/sd_diffusers_img2img_ui.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "cellView": "form",
18 | "id": "x-z4MGy9rz9I"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "#@title #Install dependencies\n",
23 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi\n",
24 | "\n",
25 | "!pip install -r StableDiffusionUi/requirements.txt\n",
26 | "!wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\n",
27 | "\n",
28 | "!mkdir -p outputs/{txt2img,img2img}\n",
29 | "!git clone https://huggingface.co/embed/negative /content/embeddings/negative\n",
30 | "!git clone https://huggingface.co/embed/lora /content/Lora/positive\n"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "QRvNp3ElcaA3"
37 | },
38 | "source": [
39 | "# Creating model pipeline"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {
45 | "id": "dVlvLpShS9eo"
46 | },
47 | "source": [
48 | "\n",
49 |     "<details>\n",
50 |     "<summary>Description</summary> TLDR: use dropdown or just paste\n",
51 | "\n",
52 | "---\n",
53 | "\n",
54 | "You can just choose a model from a dropdown of popular models, which will also give you a link to a model page on [huggingface](https://huggingface.co) and any trigger words you need to use in the prompt.\n",
55 | "\n",
56 | "Alternatively, you can paste a model_id, url, or local path for any model from huggingface:\n",
57 | "\n",
58 | "* **model_id** looks like this: `gsdf/Counterfeit-V2.5`.\n",
59 |     "<br> You can use it for any *proper* diffusers model in huggingface, i.e. ones that have a lot of folders in their **Files**, and a file called **model_index.json**\n",
60 |     "<br> `🔲 A1111 format` toggle should be **off**.\n",
61 | "\n",
62 | "* **url** from HuggingFace may look like this:\n",
63 | "```\n",
64 | "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors\n",
65 | "```\n",
66 | "Use it when the model is in Automatic1111 format (as in, just a single file).\n",
67 |     "<br> `✅ A1111 format` toggle should be **on**.\n",
68 | "\n",
69 | "* **url** from CivitAI should look like this:\n",
70 | "```\n",
71 | "https://civitai.com/api/download/models/XXXXXX\n",
72 | "```\n",
73 | "(the link you get from \"Download\" button)\n",
74 |     "<br> `✅ A1111 format` toggle should be **on**.\n",
75 | "\n",
76 | "* **local path** is just a path to a model folder (if diffusers format) or model file (if A1111 format).\n",
77 |     "<br> Set toggle accordingly.\n",
78 | "\n",
79 |     "</details>"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {
85 | "id": "dviyU3E6jK01"
86 | },
87 | "source": [
88 | "\n",
89 |     "<details><summary>Where to find urls</summary>\n",
90 | "\n",
91 |     "A good place to find model urls is [camenduru colabs repo](https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n",
92 | "\n",
93 |     "</details>"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {
100 | "cellView": "form",
101 | "id": "k_TR5dUwznN9"
102 | },
103 | "outputs": [],
104 | "source": [
105 | "#@title Render model choice UI\n",
106 | "from StableDiffusionUi.ColabUI.ColabWrapper import ColabWrapper\n",
107 | "\n",
108 | "colab = ColabWrapper(\"outputs/img2img\")\n",
109 | "colab.render_model_index(\"/content/StableDiffusionUi/model_index.json\")"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "cellView": "form",
117 | "id": "IBdm3HvI02Yo"
118 | },
119 | "outputs": [],
120 | "source": [
121 | "#@title Load chosen model\n",
122 | "from diffusers import StableDiffusionImg2ImgPipeline\n",
123 | "colab.load_model(StableDiffusionImg2ImgPipeline)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "source": [
129 | "#@title Render Settings UI\n",
130 | "colab.render_settings()"
131 | ],
132 | "metadata": {
133 | "cellView": "form",
134 | "id": "Yz3c-nayq431"
135 | },
136 | "execution_count": null,
137 | "outputs": []
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {
142 | "id": "hLmzZmRIZMpC"
143 | },
144 | "source": [
145 | "# Generating images"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {
152 | "cellView": "form",
153 | "id": "oFruEVcl4ODt"
154 | },
155 | "outputs": [],
156 | "source": [
157 | "#@title Render UI\n",
158 | "#@markdown You don't need to run this cell again unless you want to change these settings\n",
159 | "save_images = True #@param {type:\"boolean\"}\n",
160 |     "display_previews = True #@param {type:\"boolean\"}\n",
161 | "#@markdown ---\n",
162 | "from StableDiffusionUi.ColabUI.DiffusionImg2ImgUI import DiffusionImg2ImgUI\n",
163 | "\n",
164 | "ui = DiffusionImg2ImgUI()\n",
165 | "colab.render_generation_ui(ui)"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {
172 | "cellView": "form",
173 | "id": "Y43YLjY7F7NI"
174 | },
175 | "outputs": [],
176 | "source": [
177 | "#@title Run this cell to generate images\n",
178 |     "colab.generate(save_images, display_previews)"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {
184 | "id": "gthYZKknKsw7"
185 | },
186 | "source": [
187 | "#Utils"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {
194 | "cellView": "form",
195 | "id": "9vn2Z5tNIAiv"
196 | },
197 | "outputs": [],
198 | "source": [
199 | "#@markdown ## Image viewer\n",
200 | "#@markdown Run this cell to view last results in full size\n",
201 | "from IPython.display import clear_output\n",
202 | "import ipywidgets as widgets\n",
203 | "\n",
204 | "slider = widgets.IntSlider(max=len(results.images)-1)\n",
205 | "\n",
206 | "def handler(change):\n",
207 | " slider.max = len(results.images)-1\n",
208 | " if change.new > slider.max: change.new = slider.max\n",
209 | " clear_output(wait=True)\n",
210 | " display(slider, results.images[change.new])\n",
211 | "\n",
212 | "slider.observe(handler, names='value')\n",
213 | "\n",
214 | "display(slider, results.images[slider.value])"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {
220 | "id": "kRUTJXoBc4tG"
221 | },
222 | "source": [
223 | "## Converting models from ckpt\n",
224 | "You shouldn't really need this, because diffusers now support loading from A1111 format.\n",
225 | "\n",
226 | "But if you want, you can download a model and convert it to diffusers format."
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {
232 | "id": "av9urfOqi7Ak"
233 | },
234 | "source": [
235 |     "NOTE: A good place to find model urls is [camenduru colabs repo](https://github.com/camenduru/stable-diffusion-webui-colab), just look for a line starting with `!aria2c ...`\n"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": null,
241 | "metadata": {
242 | "cellView": "form",
243 | "id": "w1JZPBKId6Zv"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "#@markdown ##Download ckpt file\n",
248 | "#@markdown Download ckpt/safetensors file from huggingface url.\n",
249 | "\n",
250 | "#@markdown ---\n",
251 | "import os\n",
252 | "\n",
253 | "url = \"https://huggingface.co/mekabu/MagicalMix_v2/resolve/main/MagicalMix_v2.safetensors\" #@param {type:\"string\"}\n",
254 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n",
255 | "model_name = url.split('/')[-1]\n",
256 | "\n",
257 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n",
258 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n",
259 | "os.system(bashCommand)"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {
266 | "cellView": "form",
267 | "id": "dOn6tR_4eJ8k"
268 | },
269 | "outputs": [],
270 | "source": [
271 | "#@markdown ##(Optional) Or just upload file\n",
272 | "from google.colab import files\n",
273 | "f = files.upload()\n",
274 | "\n",
275 | "ckpt_dump_folder = \"/content/\"\n",
276 | "for key in f.keys():\n",
277 | " model_name = key\n",
278 | " break"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": null,
284 | "metadata": {
285 | "cellView": "form",
286 | "id": "N1GOOJrej6tW"
287 | },
288 | "outputs": [],
289 | "source": [
290 | "#@markdown ##Run conversion script\n",
291 | "#@markdown Paste a path to where you want this script to cache a model into `dump_path`.\n",
292 | "\n",
293 | "#@markdown Keep `override_path` empty unless you uploaded your file to some custom directory.\n",
294 | "\n",
295 | "import os\n",
296 | "import torch\n",
297 | "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n",
298 | "\n",
299 | "from_safetensors = True #@param {type:\"boolean\"}\n",
300 | "override_path = \"\" #@param {type:\"string\"}\n",
301 | "if override_path != \"\":\n",
302 | " checkpoint_path = override_path\n",
303 | "else:\n",
304 | " checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n",
305 | "\n",
306 | "pipe = download_from_original_stable_diffusion_ckpt(\n",
307 | " checkpoint_path=checkpoint_path,\n",
308 | " original_config_file = \"/content/v1-inference.yaml\",\n",
309 | " from_safetensors=from_safetensors,\n",
310 | " )\n",
311 | "\n",
312 | "dump_path=\"models/ModelName/\" #@param {type:\"string\"}\n",
313 | "pipe.save_pretrained(dump_path, safe_serialization=from_safetensors)\n",
314 | "\n",
315 | "pipe = pipe.to(\"cuda\")\n",
316 | "pipe.safety_checker = None\n",
317 | "pipe.to(torch_dtype=torch.float16)"
318 | ]
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "metadata": {
323 | "id": "ii9UDlWQnIhg"
324 | },
325 | "source": [
326 | "## Converting VAE"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {
333 | "cellView": "form",
334 | "id": "916cyCYPnNVt"
335 | },
336 | "outputs": [],
337 | "source": [
338 | "#@markdown ##Download vae file\n",
339 | "#@markdown Basically, the same thing as model files\n",
340 | "\n",
341 | "#@markdown ---\n",
342 | "import os\n",
343 | "\n",
344 | "url = \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\" #@param {type:\"string\"}\n",
345 | "ckpt_dump_folder = \"/content/models_sd/\" #@param {type:\"string\"}\n",
346 | "model_name = url.split('/')[-1]\n",
347 | "\n",
348 | "os.system(f\"mkdir -p {ckpt_dump_folder}\")\n",
349 | "bashCommand = f\"wget {url} -P {ckpt_dump_folder} --content-disposition\"\n",
350 | "os.system(bashCommand)"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": null,
356 | "metadata": {
357 | "cellView": "form",
358 | "id": "dWC8UpJ3m8Fa"
359 | },
360 | "outputs": [],
361 | "source": [
362 | "#@markdown Paste a path to where you want this script to cache a vae into `dump_path`.\n",
363 | "from_safetensors = False #@param {type:\"boolean\"}\n",
364 | "dump_path=\"vae/VaeName/\" #@param {type:\"string\"}\n",
365 | "\n",
366 | "checkpoint_path = os.path.join(ckpt_dump_folder, model_name)\n",
367 | "bashCommand = f\"python /content/StableDiffusionUi/scripts/convert_vae_pt_to_diffusers.py --vae_pt_path {checkpoint_path} --dump_path {dump_path}\"\n",
368 | "if from_safetensors:\n",
369 | " bashCommand += \" --from_safetensors\"\n",
370 | "os.system(bashCommand)"
371 | ]
372 | }
373 | ],
374 | "metadata": {
375 | "accelerator": "GPU",
376 | "colab": {
377 | "collapsed_sections": [
378 | "hEwlf_SQXCyt",
379 | "mY0xKEW2lgmj",
380 | "lYurBnHCbZ3o",
381 | "kRUTJXoBc4tG",
382 | "ii9UDlWQnIhg"
383 | ],
384 | "provenance": [],
385 | "include_colab_link": true
386 | },
387 | "gpuClass": "standard",
388 | "kernelspec": {
389 | "display_name": "Python 3",
390 | "name": "python3"
391 | },
392 | "language_info": {
393 | "name": "python"
394 | }
395 | },
396 | "nbformat": 4,
397 | "nbformat_minor": 0
398 | }
--------------------------------------------------------------------------------
/sd_kandinsky_colab_ui.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "toc_visible": true,
8 | "authorship_tag": "ABX9TyNG2pl8jKc8G9gJx57zyOjk",
9 | "include_colab_link": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "view-in-github",
24 | "colab_type": "text"
25 | },
26 | "source": [
27 | "
"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "source": [
33 | "# Kandinsky"
34 | ],
35 | "metadata": {
36 | "id": "J5jPbraGhH-D"
37 | }
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "cellView": "form",
44 | "id": "ayR-stkCUiZh"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "#@title Install dependencies\n",
49 | "!pip install 'git+https://github.com/ai-forever/Kandinsky-2.git'\n",
50 | "!pip install git+https://github.com/openai/CLIP.git\n",
51 | "!mkdir -p outputs/{txt2img,img2img}\n",
52 | "!git clone https://github.com/oneir0mancer/stable-diffusion-diffusers-colab-ui.git StableDiffusionUi"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "source": [
58 | "#@title Create model\n",
59 | "from kandinsky2 import get_kandinsky2\n",
60 | "\n",
61 | "model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2', \n",
62 | " model_version='2.1', use_flash_attention=False)\n",
63 | "sampler = \"p_sampler\""
64 | ],
65 | "metadata": {
66 | "cellView": "form",
67 | "id": "05ppJRuPfw56"
68 | },
69 | "execution_count": null,
70 | "outputs": []
71 | },
72 | {
73 | "cell_type": "code",
74 | "source": [
75 | "#@title (Optional) Change sampler \n",
76 | "sampler = \"p_sampler\" #@param [\"p_sampler\", \"ddim_sampler\", \"plms_sampler\"]"
77 | ],
78 | "metadata": {
79 | "cellView": "form",
80 | "id": "OxmVn8OLf1al"
81 | },
82 | "execution_count": null,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "source": [
88 | "## Txt2Img"
89 | ],
90 | "metadata": {
91 | "id": "oy0VXcG8iZlj"
92 | }
93 | },
94 | {
95 | "cell_type": "code",
96 | "source": [
97 | "#@title Render UI\n",
98 | "#@markdown You don't need to run this cell again unless you want to change these settings\n",
99 | "save_images = True #@param {type:\"boolean\"}\n",
100 | "display_previewes = True #@param {type:\"boolean\"}\n",
101 | "#@markdown ---\n",
102 | "\n",
103 | "from StableDiffusionUi.ColabUI.KandinskyUI import KandinskyUI\n",
104 | "\n",
105 | "ui = KandinskyUI()\n",
106 | "ui.render()"
107 | ],
108 | "metadata": {
109 | "cellView": "form",
110 | "id": "zFlSy38ef3iV"
111 | },
112 | "execution_count": null,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "code",
117 | "source": [
118 | "#@title Run this cell to generate images\n",
119 | "images = ui.generate(model, sampler)\n",
120 | "\n",
121 | "if display_previewes:\n",
122 | " ui.display_image_previews(images)\n",
123 | "\n",
124 | "if save_images:\n",
125 | " for image in results.images:\n",
126 | " image.save(f\"outputs/txt2img/{output_index:05}.png\")\n",
127 | " print(f\"outputs/txt2img/{output_index:05}.png\")\n",
128 | " output_index += 1"
129 | ],
130 | "metadata": {
131 | "cellView": "form",
132 | "id": "z5dCbkQ0f9KV"
133 | },
134 | "execution_count": null,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "source": [
140 | "## Image mixing\n",
141 | "(WIP)"
142 | ],
143 | "metadata": {
144 | "id": "wQPK2SDpic70"
145 | }
146 | },
147 | {
148 | "cell_type": "code",
149 | "source": [
150 | "#@title Render UI\n",
151 | "import io\n",
152 | "import ipywidgets as widgets\n",
153 | "from PIL import Image\n",
154 | "\n",
155 | "text_img_a = widgets.Text(placeholder=\"Path to 1st image\")\n",
156 | "text_img_b = widgets.Text(placeholder=\"Path to 2nd image\")\n",
157 | "slider_img_a = widgets.FloatSlider(max=1, step=0.05)\n",
158 | "slider_img_b = widgets.FloatSlider(max=1, step=0.05)\n",
159 | "text_prompt_a = widgets.Text(placeholder=\"Optional 1st prompt\")\n",
160 | "text_prompt_b = widgets.Text(placeholder=\"Optional 2nd prompt\")\n",
161 | "slider_prompt_a = widgets.FloatSlider(max=1, step=0.05)\n",
162 | "slider_prompt_b = widgets.FloatSlider(max=1, step=0.05)\n",
163 | "\n",
164 | "width_field = widgets.IntText(value=512, layout=widgets.Layout(width='200px'), description=\"width: \")\n",
165 | "height_field = widgets.IntText(value=768, layout=widgets.Layout(width='200px'), description=\"height: \")\n",
166 | "steps_field = widgets.IntSlider(value=20, min=1, max=100, description=\"Steps: \")\n",
167 | "cfg_field = widgets.FloatSlider(value=7, min=1, max=20, step=0.5, description=\"CFG: \")\n",
168 | "\n",
169 | "display(widgets.HBox([text_img_a, text_img_b]), widgets.HBox([slider_img_a, slider_img_b]))\n",
170 | "display(widgets.HBox([text_prompt_a, text_prompt_b]), widgets.HBox([slider_prompt_a, slider_prompt_b]))\n",
171 | "display(widgets.HBox([width_field, height_field]), steps_field, cfg_field)"
172 | ],
173 | "metadata": {
174 | "cellView": "form",
175 | "id": "csc--queiiNE"
176 | },
177 | "execution_count": null,
178 | "outputs": []
179 | },
180 | {
181 | "cell_type": "code",
182 | "source": [
183 | "#@title Run this cell to generate\n",
184 | "def display_image_previews(images):\n",
185 | " wgts = []\n",
186 | " for image in images:\n",
187 | " img_byte_arr = io.BytesIO()\n",
188 | " image.save(img_byte_arr, format='PNG')\n",
189 | " wgt = widgets.Image(\n",
190 | " value=img_byte_arr.getvalue(),\n",
191 | " width=image.width/2,\n",
192 | " height=image.height/2,\n",
193 | " )\n",
194 | " wgts.append(wgt)\n",
195 | " display(widgets.HBox(wgts))\n",
196 | "\n",
197 | "img_a = Image.open(text_img_a.value)\n",
198 | "img_b = Image.open(text_img_b.value)\n",
199 | "images_texts = [text_prompt_a.value, img_a, img_b, text_prompt_b.value]\n",
200 | "weights = [slider_prompt_a.value, slider_img_a.value, slider_img_b.value, slider_prompt_b.value]\n",
201 | "\n",
202 | "images = model.mix_images(\n",
203 | " images_texts, \n",
204 | " weights, \n",
205 | " num_steps=steps_field.value,\n",
206 | " batch_size=1, \n",
207 | " guidance_scale=cfg_field.value,\n",
208 | " h=height_field.value, w=width_field.value,\n",
209 | " sampler=sampler, \n",
210 | " prior_cf_scale=4,\n",
211 | " prior_steps=\"5\"\n",
212 | ")\n",
213 | "display_image_previews([img_a, images[0], img_b])"
214 | ],
215 | "metadata": {
216 | "cellView": "form",
217 | "id": "bz0Ty0b4nA2E"
218 | },
219 | "execution_count": null,
220 | "outputs": []
221 | }
222 | ]
223 | }
--------------------------------------------------------------------------------
/utils/downloader.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import werkzeug.http
3 | import os
4 | from tqdm.notebook import tqdm
5 |
6 | def download_ckpt(ckpt_url, output_dir=None):
7 | """
8 | download the checkpoint & return file name to be used in diffusers pipeline
9 | additional actions:
10 | - try to get file name from header or url
11 | - download only if file not exist
12 | - download with progress bar
13 | """
14 | HEADERS = {"User-Agent": "github/oneir0mancer/stable-diffusion-diffusers-colab-ui"}
15 | with requests.get(ckpt_url, headers=HEADERS, stream=True) as resp:
16 |
17 | # get file name
18 | MISSING_FILENAME = "missing_name"
19 | if content_disposition := resp.headers.get("Content-Disposition"):
20 | param, options = werkzeug.http.parse_options_header(content_disposition)
21 | if param == "attachment":
22 | filename = options.get("filename", MISSING_FILENAME)
23 | else:
24 | filename = MISSING_FILENAME
25 | else:
26 | filename = os.path.basename(ckpt_url)
27 | fileext = os.path.splitext(filename)[-1]
28 | if fileext == "":
29 | filename = MISSING_FILENAME
30 |
31 | if output_dir is not None: filename = os.path.join(output_dir, filename)
32 |
33 | # download file
34 | if not os.path.exists(filename):
35 | TOTAL_SIZE = int(resp.headers.get("Content-Length", 0))
36 |             CHUNK_SIZE = 50 * 1024**2  # 50 MiB chunks; Colab's connection is fast enough for large chunks
37 | with open(filename, mode="wb") as file, tqdm(total=TOTAL_SIZE, desc="download checkpoint", unit="iB", unit_scale=True) as progress_bar:
38 | for chunk in resp.iter_content(chunk_size=CHUNK_SIZE):
39 | progress_bar.update(len(chunk))
40 | file.write(chunk)
41 | return filename
42 |
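43 | # Usage sketch (hypothetical URL; note that `output_dir` must already exist):
44 | #   path = download_ckpt("https://huggingface.co/acme/model/resolve/main/model.safetensors",
45 | #                        output_dir="/content/models_sd")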
--------------------------------------------------------------------------------
/utils/empty_output.py:
--------------------------------------------------------------------------------
1 | class EmptyOutput:
2 | """An empty wrapper around output, for when Output widget wasn't provided"""
3 | def __enter__(self): pass
4 |
5 |     def __exit__(self, exc_type, exc_value, traceback): pass
6 |
7 | def clear_output(self): pass
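8 | 
9 | # Usage sketch (hypothetical names): fall back to EmptyOutput when no Output widget
10 | # was provided, so `with out:` and out.clear_output() work either way:
11 | #   out = maybe_output if maybe_output is not None else EmptyOutput()
12 | #   with out: do_render()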
--------------------------------------------------------------------------------
/utils/event.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | class Event:
4 | def __init__(self):
5 | self.__callbacks = []
6 |
7 | def add_callback(self, callback: Callable):
8 | self.__callbacks.append(callback)
9 |
10 | def remove_callback(self, callback: Callable):
11 | self.__callbacks.remove(callback)
12 |
13 | def clear_callbacks(self):
14 | self.__callbacks.clear()
15 |
16 | def invoke(self, event_args: list = None):
17 | if event_args is None:
18 | for x in self.__callbacks: x()
19 | else:
20 | for x in self.__callbacks: x(event_args)
21 |
22 | @property
23 | def callbacks(self): return self.__callbacks
24 |
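25 | # Usage sketch: a minimal observer-pattern example
26 | #   on_generate = Event()
27 | #   on_generate.add_callback(lambda args: print("generated:", args))
28 | #   on_generate.invoke(["outputs/txt2img/00001.png"])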
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from PIL.PngImagePlugin import PngInfo
3 |
4 | def load_image_metadata(filepath: str):
5 | image = Image.open(filepath)
6 |     prompt, etc = image.text["Data"].split("\nNegative: ", 1)
7 |     prompt = prompt.replace("\nPrompt: ", "")
8 |     negative_prompt, etc = etc.split("\nCGF: ", 1)
9 |     # TODO perhaps parse the remaining metadata into a dict
10 |     return prompt, negative_prompt
11 |
12 | def save_image_with_metadata(image, path, metadata_provider, additional_data = ""):
13 | meta = PngInfo()
14 | meta.add_text("Data", metadata_provider.metadata + additional_data)
15 | image.save(path, pnginfo=meta)
16 |
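17 | # Usage sketch (hypothetical `meta_provider` exposing a `.metadata` string in the
18 | # "\nPrompt: ...\nNegative: ...\nCGF: ..." format that load_image_metadata expects):
19 | #   save_image_with_metadata(img, "outputs/txt2img/00001.png", meta_provider)
20 | #   prompt, negative_prompt = load_image_metadata("outputs/txt2img/00001.png")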
--------------------------------------------------------------------------------
/utils/kohya_lora_loader.py:
--------------------------------------------------------------------------------
1 | #from: https://gist.github.com/takuma104/e38d683d72b1e448b8d9b3835f7cfa44
2 | import math
3 | import safetensors
4 | import torch
5 | from diffusers import DiffusionPipeline
6 |
7 | """
8 | Kohya's LoRA format Loader for Diffusers
9 | Usage:
10 | ```py
11 | # The usual diffusers setup
12 | import torch
13 | from diffusers import StableDiffusionPipeline
14 | pipe = StableDiffusionPipeline.from_pretrained('...',
15 | torch_dtype=torch.float16).to('cuda')
16 | # Import this module
17 | import kohya_lora_loader
18 | # Install LoRA hook. This appends apply_lora and remove_lora methods to the pipe.
19 | kohya_lora_loader.install_lora_hook(pipe)
20 | # Load 'lora1.safetensors' file and apply
21 | lora1 = pipe.apply_lora('lora1.safetensors', 1.0)
22 | # You can change alpha
23 | lora1.alpha = 0.5
24 | # Load 'lora2.safetensors' file and apply
25 | lora2 = pipe.apply_lora('lora2.safetensors', 1.0)
26 | # Generate image with lora1 and lora2 applied
27 | pipe(...).images[0]
28 | # Remove lora2
29 | pipe.remove_lora(lora2)
30 | # Generate image with lora1 applied
31 | pipe(...).images[0]
32 | # Uninstall LoRA hook
33 | kohya_lora_loader.uninstall_lora_hook(pipe)
34 | # Generate image with no LoRA applied
35 | pipe(...).images[0]
36 | ```
37 | """
38 |
39 |
40 | # modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
41 | class LoRAModule(torch.nn.Module):
42 | def __init__(
43 | self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
44 | ):
45 | """if alpha == 0 or None, alpha is rank (no scaling)."""
46 | super().__init__()
47 |
48 | if org_module.__class__.__name__ == "Conv2d":
49 | in_dim = org_module.in_channels
50 | out_dim = org_module.out_channels
51 | else:
52 | in_dim = org_module.in_features
53 | out_dim = org_module.out_features
54 |
55 | self.lora_dim = lora_dim
56 |
57 | if org_module.__class__.__name__ == "Conv2d":
58 | kernel_size = org_module.kernel_size
59 | stride = org_module.stride
60 | padding = org_module.padding
61 | self.lora_down = torch.nn.Conv2d(
62 | in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
63 | )
64 | self.lora_up = torch.nn.Conv2d(
65 | self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
66 | )
67 | else:
68 | self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
69 | self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
70 |
71 | if alpha is None or alpha == 0:
72 | self.alpha = self.lora_dim
73 | else:
74 |             if isinstance(alpha, torch.Tensor):
75 |                 alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
76 | self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant.
77 |
78 | # same as microsoft's
79 | torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
80 | torch.nn.init.zeros_(self.lora_up.weight)
81 |
82 | self.multiplier = multiplier
83 |
84 | def forward(self, x):
85 | scale = self.alpha / self.lora_dim
86 | return self.multiplier * scale * self.lora_up(self.lora_down(x))
87 |
88 |
89 | class LoRAModuleContainer(torch.nn.Module):
90 | def __init__(self, hooks, state_dict, multiplier):
91 | super().__init__()
92 | self.multiplier = multiplier
93 |
94 | # Create LoRAModule from state_dict information
95 | for key, value in state_dict.items():
96 | if "lora_down" in key:
97 | lora_name = key.split(".")[0]
98 | lora_dim = value.size()[0]
99 | lora_name_alpha = key.split(".")[0] + '.alpha'
100 | alpha = None
101 | if lora_name_alpha in state_dict:
102 | alpha = state_dict[lora_name_alpha].item()
103 | hook = hooks[lora_name]
104 | lora_module = LoRAModule(
105 | hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier
106 | )
107 | self.register_module(lora_name, lora_module)
108 |
109 | # Load whole LoRA weights
110 | self.load_state_dict(state_dict)
111 |
112 | # Register LoRAModule to LoRAHook
113 | for name, module in self.named_modules():
114 | if module.__class__.__name__ == "LoRAModule":
115 | hook = hooks[name]
116 | hook.append_lora(module)
117 | @property
118 | def alpha(self):
119 | return self.multiplier
120 |
121 | @alpha.setter
122 | def alpha(self, multiplier):
123 | self.multiplier = multiplier
124 | for name, module in self.named_modules():
125 | if module.__class__.__name__ == "LoRAModule":
126 | module.multiplier = multiplier
127 |
128 | def remove_from_hooks(self, hooks):
129 | for name, module in self.named_modules():
130 | if module.__class__.__name__ == "LoRAModule":
131 | hook = hooks[name]
132 | hook.remove_lora(module)
133 | del module
134 |
135 |
136 | class LoRAHook(torch.nn.Module):
137 | """
138 | replaces forward method of the original Linear,
139 | instead of replacing the original Linear module.
140 | """
141 |
142 | def __init__(self):
143 | super().__init__()
144 | self.lora_modules = []
145 |
146 | def install(self, orig_module):
147 | assert not hasattr(self, "orig_module")
148 | self.orig_module = orig_module
149 | self.orig_forward = self.orig_module.forward
150 | self.orig_module.forward = self.forward
151 |
152 | def uninstall(self):
153 | assert hasattr(self, "orig_module")
154 | self.orig_module.forward = self.orig_forward
155 | del self.orig_forward
156 | del self.orig_module
157 |
158 | def append_lora(self, lora_module):
159 | self.lora_modules.append(lora_module)
160 |
161 | def remove_lora(self, lora_module):
162 | self.lora_modules.remove(lora_module)
163 |
164 | def forward(self, x):
165 | if len(self.lora_modules) == 0:
166 | return self.orig_forward(x)
167 | lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0)
168 | return self.orig_forward(x) + lora
169 |
170 |
171 | class LoRAHookInjector(object):
172 | def __init__(self):
173 | super().__init__()
174 | self.hooks = {}
175 | self.device = None
176 | self.dtype = None
177 |
178 | def _get_target_modules(self, root_module, prefix, target_replace_modules):
179 | target_modules = []
180 | for name, module in root_module.named_modules():
181 |             if (
182 |                 module.__class__.__name__ in target_replace_modules
183 |                 and "transformer_blocks" not in name
184 |             ):  # to adapt to the latest diffusers
185 | for child_name, child_module in module.named_modules():
186 | is_linear = child_module.__class__.__name__ == "Linear"
187 | is_conv2d = child_module.__class__.__name__ == "Conv2d"
188 | if is_linear or is_conv2d:
189 | lora_name = prefix + "." + name + "." + child_name
190 | lora_name = lora_name.replace(".", "_")
191 | target_modules.append((lora_name, child_module))
192 | return target_modules
193 |
194 | def install_hooks(self, pipe):
195 | """Install LoRAHook to the pipe."""
196 | assert len(self.hooks) == 0
197 | text_encoder_targets = self._get_target_modules(
198 | pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
199 | )
200 | unet_targets = self._get_target_modules(
201 | pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
202 | )
203 | for name, target_module in text_encoder_targets + unet_targets:
204 | hook = LoRAHook()
205 | hook.install(target_module)
206 | self.hooks[name] = hook
207 |
208 | self.device = pipe.device
209 | self.dtype = pipe.unet.dtype
210 |
211 | def uninstall_hooks(self):
212 | """Uninstall LoRAHook from the pipe."""
213 | for k, v in self.hooks.items():
214 | v.uninstall()
215 | self.hooks = {}
216 |
217 | def apply_lora(self, filename, alpha=1.0):
218 | """Load LoRA weights and apply LoRA to the pipe."""
219 | assert len(self.hooks) != 0
220 | state_dict = safetensors.torch.load_file(filename)
221 | container = LoRAModuleContainer(self.hooks, state_dict, alpha)
222 | container.to(self.device, self.dtype)
223 | return container
224 |
225 | def remove_lora(self, container):
226 | """Remove the individual LoRA from the pipe."""
227 | container.remove_from_hooks(self.hooks)
228 |
229 |
230 | def install_lora_hook(pipe: DiffusionPipeline):
231 | """Install LoRAHook to the pipe."""
232 | assert not hasattr(pipe, "lora_injector")
233 | assert not hasattr(pipe, "apply_lora")
234 | assert not hasattr(pipe, "remove_lora")
235 | injector = LoRAHookInjector()
236 | injector.install_hooks(pipe)
237 | pipe.lora_injector = injector
238 | pipe.apply_lora = injector.apply_lora
239 | pipe.remove_lora = injector.remove_lora
240 |
241 |
242 | def uninstall_lora_hook(pipe: DiffusionPipeline):
243 | """Uninstall LoRAHook from the pipe."""
244 | pipe.lora_injector.uninstall_hooks()
245 | del pipe.lora_injector
246 | del pipe.apply_lora
247 | del pipe.remove_lora
--------------------------------------------------------------------------------
/utils/markdown.py:
--------------------------------------------------------------------------------
1 | import markdown
2 | from ipywidgets import HTML
3 |
4 | def SpoilerLabel(warning:str, spoiler_text:str):
5 | md = f"""
6 |
7 | {warning}
8 | {spoiler_text}
9 |
10 | """
11 | return HTML(markdown.markdown(md))
12 |
13 | def MarkdownLabel(text:str):
14 | return HTML(markdown.markdown(text))
15 |
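16 | # Usage sketch: render a collapsible spoiler and a plain markdown label in a notebook
17 | #   display(SpoilerLabel("Warning: prompt spoilers below", "masterpiece, best quality, ..."))
18 | #   display(MarkdownLabel("**Model loaded**, ready to generate"))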
--------------------------------------------------------------------------------
/utils/preview.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from ipywidgets import Image, HBox, VBox, Button, Layout
4 | from typing import List
5 |
6 | def get_image_previews(paths: List[str], width: int, height: int,
7 | favourite_dir: str = "outputs/favourite") -> HBox:
8 | """Creates a widget preview for every image in paths"""
9 | wgts = []
10 | for filepath in paths:
11 | preview = Image(
12 | value=open(filepath, "rb").read(),
13 | width=width,
14 | height=height,
15 | )
16 | btn = _add_favourite_button(filepath, favourite_dir)
17 | wgt = VBox([preview, btn])
18 | wgts.append(wgt)
19 |
20 | return HBox(wgts)
21 |
22 | def copy_file(file_path: str, to_path: str):
23 | """Copy file at `file_path` to `to_path` directory"""
24 | try_create_dir(to_path)
25 | filename = os.path.basename(file_path)
26 | shutil.copyfile(file_path, os.path.join(to_path, filename))
27 |
28 | def try_create_dir(path: str):
29 | """Create folder at `path` if it doesn't already exist"""
30 | try:
31 | os.makedirs(path)
32 | except FileExistsError: pass
33 |
34 | def _add_favourite_button(filepath: str, to_dir: str) -> Button:
35 | btn = Button(description="Favourite", tooltip=filepath, layout=Layout(width="99%"))
36 | def btn_handler(b):
37 | copy_file(filepath, to_dir)
38 | btn.on_click(btn_handler)
39 | return btn
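40 | 
41 | # Usage sketch (hypothetical paths): render half-size previews with Favourite buttons
42 | #   display(get_image_previews(["outputs/txt2img/00000.png", "outputs/txt2img/00001.png"],
43 | #                              width=256, height=256))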
--------------------------------------------------------------------------------
/utils/ui.py:
--------------------------------------------------------------------------------
1 | from ipywidgets import Button, Layout
2 |
3 | class SwitchButton:
4 | def __init__(self, description_on_switch, description_off_switch, layout,
5 | callback_on, callback_off, start_on=False):
6 | self.description_on_switch = description_on_switch
7 | self.description_off_switch = description_off_switch
8 | self.callback_on = callback_on
9 | self.callback_off = callback_off
10 | self.state = start_on
11 |
12 | self.button = Button(description="", layout=layout)
13 | if start_on: self.__set_on_state()
14 | else: self.__set_off_state()
15 |
16 | def switch_on(self, b):
17 | if self.state is True: return
18 | self.callback_on()
19 | self.__set_on_state()
20 |
21 | def switch_off(self, b):
22 | if self.state is False: return
23 | self.callback_off()
24 | self.__set_off_state()
25 |
26 | def __set_on_state(self):
27 | self.button.on_click(self.switch_on, True)
28 | self.button.on_click(self.switch_off)
29 | self.button.description = self.description_off_switch
30 | self.button.button_style = "success"
31 | self.state = True
32 |
33 | def __set_off_state(self):
34 | self.button.on_click(self.switch_off, True)
35 | self.button.on_click(self.switch_on)
36 | self.button.description = self.description_on_switch
37 | self.button.button_style = "danger"
38 | self.state = False
39 |
40 | @property
41 | def render_element(self):
42 | return self.button
43 |
44 | def render(self):
45 | display(self.render_element)
46 |
47 | def _ipython_display_(self):
48 | self.render()
49 |
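50 | # Usage sketch: a two-state toggle wired to hypothetical enable/disable callbacks
51 | #   btn = SwitchButton("Enable", "Disable", Layout(width="200px"),
52 | #                      callback_on=lambda: print("switched on"),
53 | #                      callback_off=lambda: print("switched off"))
54 | #   btn.render()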
--------------------------------------------------------------------------------