├── .gitattributes
├── .github
│   ├── Paperspace.png
│   └── Scenario.png
├── .gitignore
├── AUTOMATIC1111_files
│   ├── CN_models.txt
│   ├── CN_models_XL.txt
│   ├── CN_models_v2.txt
│   ├── Patch
│   ├── blocks.py
│   ├── paths.py
│   └── styles.py
├── Dependencies
│   ├── 1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb
│   ├── A1111.txt
│   ├── aptdeps.txt
│   ├── aptdeps_311.txt
│   ├── cloudflared-linux-amd64.deb
│   ├── dbdeps.txt
│   ├── git-lfs_2.3.4-1_amd64.deb
│   ├── google-perftools_2.5-2.2ubuntu3_all.deb
│   ├── libc-ares2_1.15.0-1ubuntu0.2_amd64.deb
│   ├── libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb
│   ├── libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb
│   ├── libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb
│   ├── libzaria2-0_1.35.0-1build1_amd64.deb
│   ├── man-db_2.9.1-1_amd64.deb
│   ├── rename_1.10-1_all.deb
│   ├── rnpd_deps.txt
│   ├── rnpddeps.txt
│   ├── unzip_6.0-25ubuntu1.1_amd64.deb
│   ├── zaria2_1.35.0-1build1_amd64.deb
│   ├── zip_3.0-11build1_amd64.deb
│   └── zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb
├── Dreambooth
│   ├── 1.jpg
│   ├── 2.png
│   ├── 3.JPG
│   ├── 4.jpg
│   ├── blocks.py
│   ├── convertodiffv1.py
│   ├── convertodiffv2-768.py
│   ├── convertodiffv2.py
│   ├── convertosd.py
│   ├── convertosdv2.py
│   ├── det.py
│   ├── hub.py
│   ├── ldm.zip
│   ├── model_index.json
│   ├── refmdlz
│   ├── scheduler_config.json
│   └── smart_crop.py
├── LICENSE
├── README.md
├── fast-DreamBooth.ipynb
└── fast_stable_diffusion_AUTOMATIC1111.ipynb

/.gitattributes:
--------------------------------------------------------------------------------
1 | Dependencies/db_deps.tar.zst filter=lfs diff=lfs merge=lfs -text
2 | A1111_dep.tar.zst filter=lfs diff=lfs merge=lfs -text
3 | 
--------------------------------------------------------------------------------
/.github/Paperspace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/.github/Paperspace.png
--------------------------------------------------------------------------------
/.github/Scenario.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/.github/Scenario.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
--------------------------------------------------------------------------------
/AUTOMATIC1111_files/CN_models.txt:
--------------------------------------------------------------------------------
1 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth
2 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1p_sd15_depth.pth
3 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth
4 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_mlsd.pth
5 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_normalbae.pth
6 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth
7 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_scribble.pth
8 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_seg.pth
9 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_ip2p.pth
10 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_shuffle.pth
11 | 
https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_inpaint.pth 12 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth 13 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15s2_lineart_anime.pth 14 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth 15 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_keypose-fp16.safetensors 16 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_seg-fp16.safetensors 17 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_sketch-fp16.safetensors 18 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_depth-fp16.safetensors 19 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_canny-fp16.safetensors 20 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_color-fp16.safetensors 21 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_style-fp16.safetensors 22 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_openpose-fp16.safetensors 23 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/CN_models_XL.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/diffusers_xl_canny_mid.safetensors 2 | https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/diffusers_xl_depth_mid.safetensors 3 | https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_sketch_256lora.safetensors 4 | https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/thibaud_xl_openpose_256lora.safetensors 5 | https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_recolor_128lora.safetensors 6 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/CN_models_v2.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_canny.safetensors 2 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_depth.safetensors 3 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_hed.safetensors 4 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_openpose.safetensors 5 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_scribble.safetensors 6 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/Patch: -------------------------------------------------------------------------------- 1 | diff --git a/Makefile.am b/Makefile.am 2 | index f18bf4f..10cc9d6 100755 3 | --- a/Makefile.am 4 | +++ b/Makefile.am 5 | @@ -102,6 +102,7 @@ if HAVE_OBJCOPY_WEAKEN 6 | WEAKEN = $(OBJCOPY) -W malloc -W free -W realloc -W calloc -W cfree \ 7 | -W memalign -W posix_memalign -W valloc -W pvalloc \ 8 | -W malloc_stats -W mallopt -W mallinfo \ 9 | + -W aligned_alloc \ 10 | -W _Znwm -W _ZnwmRKSt9nothrow_t -W _Znam -W _ZnamRKSt9nothrow_t \ 11 | -W _ZdlPv -W _ZdaPv \ 12 | -W __Znwm -W __ZnwmRKSt9nothrow_t -W __Znam -W __ZnamRKSt9nothrow_t \ 13 | diff --git a/src/libc_override_gcc_and_weak.h 
b/src/libc_override_gcc_and_weak.h 14 | index ecb66ec..1f19e01 100644 15 | --- a/src/libc_override_gcc_and_weak.h 16 | +++ b/src/libc_override_gcc_and_weak.h 17 | @@ -143,6 +143,7 @@ extern "C" { 18 | void* calloc(size_t n, size_t size) __THROW ALIAS(tc_calloc); 19 | void cfree(void* ptr) __THROW ALIAS(tc_cfree); 20 | void* memalign(size_t align, size_t s) __THROW ALIAS(tc_memalign); 21 | + void* aligned_alloc(size_t align, size_t s) __THROW ALIAS(tc_memalign); 22 | void* valloc(size_t size) __THROW ALIAS(tc_valloc); 23 | void* pvalloc(size_t size) __THROW ALIAS(tc_pvalloc); 24 | int posix_memalign(void** r, size_t a, size_t s) __THROW 25 | diff --git a/src/libc_override_redefine.h b/src/libc_override_redefine.h 26 | index 72679ef..89ad584 100644 27 | --- a/src/libc_override_redefine.h 28 | +++ b/src/libc_override_redefine.h 29 | @@ -71,6 +71,7 @@ extern "C" { 30 | void* calloc(size_t n, size_t s) { return tc_calloc(n, s); } 31 | void cfree(void* p) { tc_cfree(p); } 32 | void* memalign(size_t a, size_t s) { return tc_memalign(a, s); } 33 | + void* aligned_alloc(size_t a, size_t s) { return tc_memalign(a, s); } 34 | void* valloc(size_t s) { return tc_valloc(s); } 35 | void* pvalloc(size_t s) { return tc_pvalloc(s); } 36 | int posix_memalign(void** r, size_t a, size_t s) { 37 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir 5 | import modules.safe 6 | 7 | # data_path = cmd_opts_pre.data 8 | sys.path.insert(0, script_path) 9 | 10 | # search for directory of stable diffusion in following places 11 | sd_path = None 12 | possible_sd_paths = [os.path.join(script_path, '/content/gdrive/MyDrive/sd/stablediffusion'), '.', os.path.dirname(script_path)] 13 | for possible_sd_path in possible_sd_paths: 14 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 15 | sd_path = os.path.abspath(possible_sd_path) 16 | break 17 | 18 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 19 | 20 | path_dirs = [ 21 | (sd_path, 'ldm', 'Stable Diffusion', []), 22 | (os.path.join(sd_path, 'src/taming-transformers'), 'taming', 'Taming Transformers', []), 23 | (os.path.join(sd_path, 'src/codeformer'), 'inference_codeformer.py', 'CodeFormer', []), 24 | (os.path.join(sd_path, 'src/blip'), 'models/blip.py', 'BLIP', []), 25 | (os.path.join(sd_path, 'src/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 26 | ] 27 | 28 | paths = {} 29 | 30 | for d, must_exist, what, options in path_dirs: 31 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 32 | if not os.path.exists(must_exist_path): 33 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 34 | else: 35 | d = os.path.abspath(d) 36 | if "atstart" in options: 37 | sys.path.insert(0, d) 38 | else: 39 | sys.path.append(d) 40 | paths[what] = d 41 | 42 | class Prioritize: 43 | def __init__(self, name): 44 | self.name = name 45 | self.path = None 46 | 47 | def __enter__(self): 48 | self.path = sys.path.copy() 49 | sys.path = [paths[self.name]] + sys.path 50 | 51 | def __exit__(self, exc_type, exc_val, exc_tb): 52 | sys.path = self.path 53 | self.path = None 54 | 
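A brief usage sketch of the Prioritize context manager defined above; this is an editorial illustration, not part of paths.py. It temporarily puts one of the directories registered in path_dirs at the front of sys.path so that an import resolves from that checkout first. The repo name "k_diffusion" is one of the names populated into the paths dict above, and the imported module is only an example.

# Hypothetical usage of modules.paths.Prioritize (illustrative only). Assumes this
# file is installed as modules/paths.py inside the AUTOMATIC1111 web UI, as its own
# imports (modules.paths_internal, modules.safe) imply, and that the k-diffusion
# checkout was found so paths["k_diffusion"] exists.
import importlib

from modules import paths

with paths.Prioritize("k_diffusion"):
    # Inside the block, paths.paths["k_diffusion"] sits first on sys.path,
    # so this import resolves against the registered k-diffusion checkout.
    sampling = importlib.import_module("k_diffusion.sampling")
# On exit, sys.path is restored from the copy saved in __enter__.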
-------------------------------------------------------------------------------- /AUTOMATIC1111_files/styles.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from modules import errors 3 | import csv 4 | import os 5 | import typing 6 | import shutil 7 | 8 | 9 | class PromptStyle(typing.NamedTuple): 10 | name: str 11 | prompt: str 12 | negative_prompt: str 13 | path: str 14 | 15 | 16 | def merge_prompts(style_prompt: str, prompt: str) -> str: 17 | if "{prompt}" in style_prompt: 18 | res = style_prompt.replace("{prompt}", prompt) 19 | else: 20 | parts = filter(None, (prompt.strip(), style_prompt.strip())) 21 | res = ", ".join(parts) 22 | 23 | return res 24 | 25 | 26 | def apply_styles_to_prompt(prompt, styles): 27 | for style in styles: 28 | prompt = merge_prompts(style, prompt) 29 | 30 | return prompt 31 | 32 | 33 | def extract_style_text_from_prompt(style_text, prompt): 34 | """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt. 35 | 36 | extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg") 37 | extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg") 38 | extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg") 39 | """ 40 | 41 | stripped_prompt = prompt.strip() 42 | stripped_style_text = style_text.strip() 43 | 44 | if "{prompt}" in stripped_style_text: 45 | left, right = stripped_style_text.split("{prompt}", 2) 46 | if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): 47 | prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] 48 | return True, prompt 49 | else: 50 | if stripped_prompt.endswith(stripped_style_text): 51 | prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] 52 | 53 | if prompt.endswith(', '): 54 | prompt = prompt[:-2] 55 | 56 | return True, prompt 57 | 58 | return False, prompt 59 | 60 | 61 | def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): 62 | """ 63 | Takes a style and compares it to the prompt and negative prompt. If the style 64 | matches, returns True plus the prompt and negative prompt with the style text 65 | removed. Otherwise, returns False with the original prompt and negative prompt. 66 | """ 67 | if not style.prompt and not style.negative_prompt: 68 | return False, prompt, negative_prompt 69 | 70 | match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) 71 | if not match_positive: 72 | return False, prompt, negative_prompt 73 | 74 | match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) 75 | if not match_negative: 76 | return False, prompt, negative_prompt 77 | 78 | return True, extracted_positive, extracted_negative 79 | 80 | 81 | class StyleDatabase: 82 | def __init__(self, paths: list[str]): 83 | self.no_style = PromptStyle("None", "", "", None) 84 | self.styles = {} 85 | self.paths = paths 86 | self.all_styles_files: list[Path] = [] 87 | 88 | folder, file = os.path.split(self.paths[0]) 89 | if '*' in file or '?' 
in file: 90 | # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path 91 | self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv'))) 92 | self.paths.insert(0, self.default_path) 93 | else: 94 | self.default_path = Path(self.paths[0]) 95 | 96 | self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] 97 | 98 | self.reload() 99 | 100 | def reload(self): 101 | """ 102 | Clears the style database and reloads the styles from the CSV file(s) 103 | matching the path used to initialize the database. 104 | """ 105 | self.styles.clear() 106 | 107 | # scans for all styles files 108 | all_styles_files = [] 109 | for pattern in self.paths: 110 | folder, file = os.path.split(pattern) 111 | if '*' in file or '?' in file: 112 | found_files = Path(folder).glob(file) 113 | [all_styles_files.append(file) for file in found_files] 114 | else: 115 | # if os.path.exists(pattern): 116 | all_styles_files.append(Path(pattern)) 117 | 118 | # Remove any duplicate entries 119 | seen = set() 120 | self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))] 121 | 122 | for styles_file in self.all_styles_files: 123 | if len(all_styles_files) > 1: 124 | # add divider when more than styles file 125 | # '---------------- STYLES ----------------' 126 | divider = f' {styles_file.stem.upper()} '.center(40, '-') 127 | self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save") 128 | if styles_file.is_file(): 129 | self.load_from_csv(styles_file) 130 | 131 | def load_from_csv(self, path: str): 132 | try: 133 | with open(path, "r", encoding="utf-8-sig", newline="") as file: 134 | reader = csv.DictReader(file, skipinitialspace=True) 135 | for row in reader: 136 | # Ignore empty rows or rows starting with a comment 137 | if not row or row["name"].startswith("#"): 138 | continue 139 | # Support loading old CSV format with "name, text"-columns 140 | prompt = row["prompt"] if "prompt" in row else row["text"] 141 | negative_prompt = row.get("negative_prompt", "") 142 | # Add style to database 143 | self.styles[row["name"]] = PromptStyle( 144 | row["name"], prompt, negative_prompt, str(path) 145 | ) 146 | except Exception: 147 | errors.report(f'Error loading styles from {path}: ', exc_info=True) 148 | 149 | def get_style_paths(self) -> set: 150 | """Returns a set of all distinct paths of files that styles are loaded from.""" 151 | # Update any styles without a path to the default path 152 | for style in list(self.styles.values()): 153 | if not style.path: 154 | self.styles[style.name] = style._replace(path=str(self.default_path)) 155 | 156 | # Create a list of all distinct paths, including the default path 157 | style_paths = set() 158 | style_paths.add(str(self.default_path)) 159 | for _, style in self.styles.items(): 160 | if style.path: 161 | style_paths.add(style.path) 162 | 163 | # Remove any paths for styles that are just list dividers 164 | style_paths.discard("do_not_save") 165 | 166 | return style_paths 167 | 168 | def get_style_prompts(self, styles): 169 | return [self.styles.get(x, self.no_style).prompt for x in styles] 170 | 171 | def get_negative_style_prompts(self, styles): 172 | return [self.styles.get(x, self.no_style).negative_prompt for x in styles] 173 | 174 | def apply_styles_to_prompt(self, prompt, styles): 175 | return apply_styles_to_prompt( 176 | prompt, [self.styles.get(x, self.no_style).prompt for x in styles] 177 | ) 178 | 179 | def 
apply_negative_styles_to_prompt(self, prompt, styles): 180 | return apply_styles_to_prompt( 181 | prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles] 182 | ) 183 | 184 | def save_styles(self, path: str = None) -> None: 185 | # The path argument is deprecated, but kept for backwards compatibility 186 | 187 | style_paths = self.get_style_paths() 188 | 189 | csv_names = [os.path.split(path)[1].lower() for path in style_paths] 190 | 191 | for style_path in style_paths: 192 | # Always keep a backup file around 193 | if os.path.exists(style_path): 194 | shutil.copy(style_path, f"{style_path}.bak") 195 | 196 | # Write the styles to the CSV file 197 | with open(style_path, "w", encoding="utf-8-sig", newline="") as file: 198 | writer = csv.DictWriter(file, fieldnames=self.prompt_fields) 199 | writer.writeheader() 200 | for style in (s for s in self.styles.values() if s.path == style_path): 201 | # Skip style list dividers, e.g. "STYLES.CSV" 202 | if style.name.lower().strip("# ") in csv_names: 203 | continue 204 | # Write style fields, ignoring the path field 205 | writer.writerow( 206 | {k: v for k, v in style._asdict().items() if k != "path"} 207 | ) 208 | 209 | def extract_styles_from_prompt(self, prompt, negative_prompt): 210 | extracted = [] 211 | 212 | applicable_styles = list(self.styles.values()) 213 | 214 | while True: 215 | found_style = None 216 | 217 | for style in applicable_styles: 218 | is_match, new_prompt, new_neg_prompt = extract_original_prompts( 219 | style, prompt, negative_prompt 220 | ) 221 | if is_match: 222 | found_style = style 223 | prompt = new_prompt 224 | negative_prompt = new_neg_prompt 225 | break 226 | 227 | if not found_style: 228 | break 229 | 230 | applicable_styles.remove(found_style) 231 | extracted.append(found_style.name) 232 | 233 | return list(reversed(extracted)), prompt, negative_prompt 234 | -------------------------------------------------------------------------------- /Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/A1111.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_mrep.tar.zst 3 | https://huggingface.co/TheLastBen/dependencies/resolve/main/gcolabdeps.tar.zst 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/cloudflared-linux-amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb 6 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb 7 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/man-db_2.9.1-1_amd64.deb 8 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zaria2_1.35.0-1build1_amd64.deb 9 | -------------------------------------------------------------------------------- /Dependencies/aptdeps.txt: -------------------------------------------------------------------------------- 1 | 
https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 4 | -------------------------------------------------------------------------------- /Dependencies/aptdeps_311.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 3 | -------------------------------------------------------------------------------- /Dependencies/cloudflared-linux-amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/cloudflared-linux-amd64.deb -------------------------------------------------------------------------------- /Dependencies/dbdeps.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/TheLastBen/dependencies/resolve/main/gcolabdeps.tar.zst 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/man-db_2.9.1-1_amd64.deb 6 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zaria2_1.35.0-1build1_amd64.deb 7 | -------------------------------------------------------------------------------- /Dependencies/git-lfs_2.3.4-1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/git-lfs_2.3.4-1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb -------------------------------------------------------------------------------- /Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- 
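The Dependencies/*.txt files above (A1111.txt, aptdeps.txt, dbdeps.txt, rnpd_deps.txt, ...) are plain newline-separated lists of download URLs. A minimal sketch, not taken from the repository, of how such a list could be read and fetched with the Python standard library; the function name and destination directory are illustrative.

# Hypothetical helper (not from the repository): download every URL listed in a
# Dependencies-style text file into a target directory.
import os
import urllib.request

def fetch_listed_files(list_path: str, dest_dir: str) -> None:
    os.makedirs(dest_dir, exist_ok=True)
    with open(list_path, "r", encoding="utf-8") as f:
        urls = [line.strip() for line in f if line.strip()]
    for url in urls:
        filename = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
        urllib.request.urlretrieve(url, filename)  # simple fetch, no resume/retry

# e.g. fetch_listed_files("Dependencies/aptdeps.txt", "./deps")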
/Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/man-db_2.9.1-1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/man-db_2.9.1-1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/rename_1.10-1_all.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/rename_1.10-1_all.deb -------------------------------------------------------------------------------- /Dependencies/rnpd_deps.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zip_3.0-11build1_amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb 6 | https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpd-310.tar.zst 7 | -------------------------------------------------------------------------------- /Dependencies/rnpddeps.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zip_3.0-11build1_amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb 6 | https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpddeps.tar.zst 7 | -------------------------------------------------------------------------------- /Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zaria2_1.35.0-1build1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/zaria2_1.35.0-1build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zip_3.0-11build1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/zip_3.0-11build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb -------------------------------------------------------------------------------- /Dreambooth/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/1.jpg -------------------------------------------------------------------------------- /Dreambooth/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/2.png -------------------------------------------------------------------------------- /Dreambooth/3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/3.JPG -------------------------------------------------------------------------------- /Dreambooth/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/4.jpg -------------------------------------------------------------------------------- /Dreambooth/convertodiffv1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig 5 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 6 | 7 | 8 | 9 | # DiffUsers版StableDiffusionのモデルパラメータ 10 | NUM_TRAIN_TIMESTEPS = 1000 11 | BETA_START = 0.00085 12 | BETA_END = 0.0120 13 | 14 | UNET_PARAMS_MODEL_CHANNELS = 320 15 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] 16 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] 17 | UNET_PARAMS_IMAGE_SIZE = 64 18 | UNET_PARAMS_IN_CHANNELS = 4 19 | UNET_PARAMS_OUT_CHANNELS = 4 20 | UNET_PARAMS_NUM_RES_BLOCKS = 2 21 | UNET_PARAMS_CONTEXT_DIM = 768 22 | UNET_PARAMS_NUM_HEADS = 8 23 | 24 | VAE_PARAMS_Z_CHANNELS = 4 25 | VAE_PARAMS_RESOLUTION = 512 26 | VAE_PARAMS_IN_CHANNELS = 3 27 | 
VAE_PARAMS_OUT_CH = 3 28 | VAE_PARAMS_CH = 128 29 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4] 30 | VAE_PARAMS_NUM_RES_BLOCKS = 2 31 | 32 | # V2 33 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] 34 | V2_UNET_PARAMS_CONTEXT_DIM = 1024 35 | 36 | 37 | # region StableDiffusion->Diffusersの変換コード 38 | # convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0) 39 | 40 | 41 | def shave_segments(path, n_shave_prefix_segments=1): 42 | """ 43 | Removes segments. Positive values shave the first segments, negative shave the last segments. 44 | """ 45 | if n_shave_prefix_segments >= 0: 46 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 47 | else: 48 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 49 | 50 | 51 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 52 | """ 53 | Updates paths inside resnets to the new naming scheme (local renaming) 54 | """ 55 | mapping = [] 56 | for old_item in old_list: 57 | new_item = old_item.replace("in_layers.0", "norm1") 58 | new_item = new_item.replace("in_layers.2", "conv1") 59 | 60 | new_item = new_item.replace("out_layers.0", "norm2") 61 | new_item = new_item.replace("out_layers.3", "conv2") 62 | 63 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 64 | new_item = new_item.replace("skip_connection", "conv_shortcut") 65 | 66 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 67 | 68 | mapping.append({"old": old_item, "new": new_item}) 69 | 70 | return mapping 71 | 72 | 73 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 74 | """ 75 | Updates paths inside resnets to the new naming scheme (local renaming) 76 | """ 77 | mapping = [] 78 | for old_item in old_list: 79 | new_item = old_item 80 | 81 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 82 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 83 | 84 | mapping.append({"old": old_item, "new": new_item}) 85 | 86 | return mapping 87 | 88 | 89 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 90 | """ 91 | Updates paths inside attentions to the new naming scheme (local renaming) 92 | """ 93 | mapping = [] 94 | for old_item in old_list: 95 | new_item = old_item 96 | 97 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 98 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 99 | 100 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 101 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 102 | 103 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 104 | 105 | mapping.append({"old": old_item, "new": new_item}) 106 | 107 | return mapping 108 | 109 | 110 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 111 | """ 112 | Updates paths inside attentions to the new naming scheme (local renaming) 113 | """ 114 | mapping = [] 115 | for old_item in old_list: 116 | new_item = old_item 117 | 118 | new_item = new_item.replace("norm.weight", "group_norm.weight") 119 | new_item = new_item.replace("norm.bias", "group_norm.bias") 120 | 121 | new_item = new_item.replace("q.weight", "query.weight") 122 | new_item = new_item.replace("q.bias", "query.bias") 123 | 124 | new_item = new_item.replace("k.weight", "key.weight") 125 | new_item = new_item.replace("k.bias", "key.bias") 126 | 127 | new_item = new_item.replace("v.weight", "value.weight") 128 | new_item = new_item.replace("v.bias", "value.bias") 129 | 130 | new_item = 
new_item.replace("proj_out.weight", "proj_attn.weight") 131 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 132 | 133 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 134 | 135 | mapping.append({"old": old_item, "new": new_item}) 136 | 137 | return mapping 138 | 139 | 140 | def assign_to_checkpoint( 141 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 142 | ): 143 | """ 144 | This does the final conversion step: take locally converted weights and apply a global renaming 145 | to them. It splits attention layers, and takes into account additional replacements 146 | that may arise. 147 | 148 | Assigns the weights to the new checkpoint. 149 | """ 150 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 151 | 152 | # Splits the attention layers into three variables. 153 | if attention_paths_to_split is not None: 154 | for path, path_map in attention_paths_to_split.items(): 155 | old_tensor = old_checkpoint[path] 156 | channels = old_tensor.shape[0] // 3 157 | 158 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 159 | 160 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 161 | 162 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 163 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 164 | 165 | checkpoint[path_map["query"]] = query.reshape(target_shape) 166 | checkpoint[path_map["key"]] = key.reshape(target_shape) 167 | checkpoint[path_map["value"]] = value.reshape(target_shape) 168 | 169 | for path in paths: 170 | new_path = path["new"] 171 | 172 | # These have already been assigned 173 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 174 | continue 175 | 176 | # Global renaming happens here 177 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 178 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 179 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 180 | 181 | if additional_replacements is not None: 182 | for replacement in additional_replacements: 183 | new_path = new_path.replace(replacement["old"], replacement["new"]) 184 | 185 | # proj_attn.weight has to be converted from conv 1D to linear 186 | if "proj_attn.weight" in new_path: 187 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 188 | else: 189 | checkpoint[new_path] = old_checkpoint[path["old"]] 190 | 191 | 192 | def conv_attn_to_linear(checkpoint): 193 | keys = list(checkpoint.keys()) 194 | attn_keys = ["query.weight", "key.weight", "value.weight"] 195 | for key in keys: 196 | if ".".join(key.split(".")[-2:]) in attn_keys: 197 | if checkpoint[key].ndim > 2: 198 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 199 | elif "proj_attn.weight" in key: 200 | if checkpoint[key].ndim > 2: 201 | checkpoint[key] = checkpoint[key][:, :, 0] 202 | 203 | 204 | def linear_transformer_to_conv(checkpoint): 205 | keys = list(checkpoint.keys()) 206 | tf_keys = ["proj_in.weight", "proj_out.weight"] 207 | for key in keys: 208 | if ".".join(key.split(".")[-2:]) in tf_keys: 209 | if checkpoint[key].ndim == 2: 210 | checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) 211 | 212 | 213 | def convert_ldm_unet_checkpoint(v2, checkpoint, config): 214 | """ 215 | Takes a state dict and a config, and returns a converted checkpoint. 
216 | """ 217 | 218 | # extract state_dict for UNet 219 | unet_state_dict = {} 220 | unet_key = "model.diffusion_model." 221 | keys = list(checkpoint.keys()) 222 | for key in keys: 223 | if key.startswith(unet_key): 224 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 225 | 226 | new_checkpoint = {} 227 | 228 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 229 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 230 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 231 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 232 | 233 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 234 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 235 | 236 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 237 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 238 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 239 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 240 | 241 | # Retrieves the keys for the input blocks only 242 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 243 | input_blocks = { 244 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 245 | for layer_id in range(num_input_blocks) 246 | } 247 | 248 | # Retrieves the keys for the middle blocks only 249 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 250 | middle_blocks = { 251 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 252 | for layer_id in range(num_middle_blocks) 253 | } 254 | 255 | # Retrieves the keys for the output blocks only 256 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 257 | output_blocks = { 258 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 259 | for layer_id in range(num_output_blocks) 260 | } 261 | 262 | for i in range(1, num_input_blocks): 263 | block_id = (i - 1) // (config["layers_per_block"] + 1) 264 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 265 | 266 | resnets = [ 267 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 268 | ] 269 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 270 | 271 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 272 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 273 | f"input_blocks.{i}.0.op.weight" 274 | ) 275 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 276 | f"input_blocks.{i}.0.op.bias" 277 | ) 278 | 279 | paths = renew_resnet_paths(resnets) 280 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 281 | assign_to_checkpoint( 282 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 283 | ) 284 | 285 | if len(attentions): 286 | paths = renew_attention_paths(attentions) 287 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 288 | assign_to_checkpoint( 289 | paths, new_checkpoint, unet_state_dict, 
additional_replacements=[meta_path], config=config 290 | ) 291 | 292 | resnet_0 = middle_blocks[0] 293 | attentions = middle_blocks[1] 294 | resnet_1 = middle_blocks[2] 295 | 296 | resnet_0_paths = renew_resnet_paths(resnet_0) 297 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 298 | 299 | resnet_1_paths = renew_resnet_paths(resnet_1) 300 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 301 | 302 | attentions_paths = renew_attention_paths(attentions) 303 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 304 | assign_to_checkpoint( 305 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 306 | ) 307 | 308 | for i in range(num_output_blocks): 309 | block_id = i // (config["layers_per_block"] + 1) 310 | layer_in_block_id = i % (config["layers_per_block"] + 1) 311 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 312 | output_block_list = {} 313 | 314 | for layer in output_block_layers: 315 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 316 | if layer_id in output_block_list: 317 | output_block_list[layer_id].append(layer_name) 318 | else: 319 | output_block_list[layer_id] = [layer_name] 320 | 321 | if len(output_block_list) > 1: 322 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 323 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 324 | 325 | resnet_0_paths = renew_resnet_paths(resnets) 326 | paths = renew_resnet_paths(resnets) 327 | 328 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 329 | assign_to_checkpoint( 330 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 331 | ) 332 | 333 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 334 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 335 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 336 | f"output_blocks.{i}.{index}.conv.weight" 337 | ] 338 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 339 | f"output_blocks.{i}.{index}.conv.bias" 340 | ] 341 | 342 | # Clear attentions as they have been attributed above. 343 | if len(attentions) == 2: 344 | attentions = [] 345 | 346 | if len(attentions): 347 | paths = renew_attention_paths(attentions) 348 | meta_path = { 349 | "old": f"output_blocks.{i}.1", 350 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 351 | } 352 | assign_to_checkpoint( 353 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 354 | ) 355 | else: 356 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 357 | for path in resnet_0_paths: 358 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 359 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 360 | 361 | new_checkpoint[new_path] = unet_state_dict[old_path] 362 | 363 | # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する 364 | if v2: 365 | linear_transformer_to_conv(new_checkpoint) 366 | 367 | return new_checkpoint 368 | 369 | 370 | def convert_ldm_vae_checkpoint(checkpoint, config): 371 | # extract state dict for VAE 372 | vae_state_dict = {} 373 | vae_key = "first_stage_model." 
374 | keys = list(checkpoint.keys()) 375 | for key in keys: 376 | if key.startswith(vae_key): 377 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 378 | # if len(vae_state_dict) == 0: 379 | # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict 380 | # vae_state_dict = checkpoint 381 | 382 | new_checkpoint = {} 383 | 384 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 385 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 386 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 387 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 388 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 389 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 390 | 391 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 392 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 393 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 394 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 395 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 396 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 397 | 398 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 399 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 400 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 401 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 402 | 403 | # Retrieves the keys for the encoder down blocks only 404 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 405 | down_blocks = { 406 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 407 | } 408 | 409 | # Retrieves the keys for the decoder up blocks only 410 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 411 | up_blocks = { 412 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 413 | } 414 | 415 | for i in range(num_down_blocks): 416 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 417 | 418 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 419 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 420 | f"encoder.down.{i}.downsample.conv.weight" 421 | ) 422 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 423 | f"encoder.down.{i}.downsample.conv.bias" 424 | ) 425 | 426 | paths = renew_vae_resnet_paths(resnets) 427 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 428 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 429 | 430 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 431 | num_mid_res_blocks = 2 432 | for i in range(1, num_mid_res_blocks + 1): 433 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 434 | 435 | paths = renew_vae_resnet_paths(resnets) 436 | meta_path = {"old": f"mid.block_{i}", 
"new": f"mid_block.resnets.{i - 1}"} 437 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 438 | 439 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 440 | paths = renew_vae_attention_paths(mid_attentions) 441 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 442 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 443 | conv_attn_to_linear(new_checkpoint) 444 | 445 | for i in range(num_up_blocks): 446 | block_id = num_up_blocks - 1 - i 447 | resnets = [ 448 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 449 | ] 450 | 451 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 452 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 453 | f"decoder.up.{block_id}.upsample.conv.weight" 454 | ] 455 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 456 | f"decoder.up.{block_id}.upsample.conv.bias" 457 | ] 458 | 459 | paths = renew_vae_resnet_paths(resnets) 460 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 461 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 462 | 463 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 464 | num_mid_res_blocks = 2 465 | for i in range(1, num_mid_res_blocks + 1): 466 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 467 | 468 | paths = renew_vae_resnet_paths(resnets) 469 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 470 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 471 | 472 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 473 | paths = renew_vae_attention_paths(mid_attentions) 474 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 475 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 476 | conv_attn_to_linear(new_checkpoint) 477 | return new_checkpoint 478 | 479 | 480 | def create_unet_diffusers_config(v2): 481 | """ 482 | Creates a config for the diffusers based on the config of the LDM model. 
483 | """ 484 | # unet_params = original_config.model.params.unet_config.params 485 | 486 | block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] 487 | 488 | down_block_types = [] 489 | resolution = 1 490 | for i in range(len(block_out_channels)): 491 | block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" 492 | down_block_types.append(block_type) 493 | if i != len(block_out_channels) - 1: 494 | resolution *= 2 495 | 496 | up_block_types = [] 497 | for i in range(len(block_out_channels)): 498 | block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" 499 | up_block_types.append(block_type) 500 | resolution //= 2 501 | 502 | config = dict( 503 | sample_size=UNET_PARAMS_IMAGE_SIZE, 504 | in_channels=UNET_PARAMS_IN_CHANNELS, 505 | out_channels=UNET_PARAMS_OUT_CHANNELS, 506 | down_block_types=tuple(down_block_types), 507 | up_block_types=tuple(up_block_types), 508 | block_out_channels=tuple(block_out_channels), 509 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, 510 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, 511 | attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, 512 | ) 513 | 514 | return config 515 | 516 | 517 | def create_vae_diffusers_config(): 518 | """ 519 | Creates a config for the diffusers based on the config of the LDM model. 520 | """ 521 | # vae_params = original_config.model.params.first_stage_config.params.ddconfig 522 | # _ = original_config.model.params.first_stage_config.params.embed_dim 523 | block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] 524 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 525 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 526 | 527 | config = dict( 528 | sample_size=VAE_PARAMS_RESOLUTION, 529 | in_channels=VAE_PARAMS_IN_CHANNELS, 530 | out_channels=VAE_PARAMS_OUT_CH, 531 | down_block_types=tuple(down_block_types), 532 | up_block_types=tuple(up_block_types), 533 | block_out_channels=tuple(block_out_channels), 534 | latent_channels=VAE_PARAMS_Z_CHANNELS, 535 | layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, 536 | ) 537 | return config 538 | 539 | 540 | def convert_ldm_clip_checkpoint_v1(checkpoint): 541 | keys = list(checkpoint.keys()) 542 | text_model_dict = {} 543 | for key in keys: 544 | if key.startswith("cond_stage_model.transformer"): 545 | text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] 546 | return text_model_dict 547 | 548 | 549 | def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): 550 | # 嫌になるくらい違うぞ! 551 | def convert_key(key): 552 | if not key.startswith("cond_stage_model"): 553 | return None 554 | 555 | # common conversion 556 | key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") 557 | key = key.replace("cond_stage_model.model.", "text_model.") 558 | 559 | if "resblocks" in key: 560 | # resblocks conversion 561 | key = key.replace(".resblocks.", ".layers.") 562 | if ".ln_" in key: 563 | key = key.replace(".ln_", ".layer_norm") 564 | elif ".mlp." 
in key: 565 | key = key.replace(".c_fc.", ".fc1.") 566 | key = key.replace(".c_proj.", ".fc2.") 567 | elif '.attn.out_proj' in key: 568 | key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") 569 | elif '.attn.in_proj' in key: 570 | key = None # 特殊なので後で処理する 571 | else: 572 | raise ValueError(f"unexpected key in SD: {key}") 573 | elif '.positional_embedding' in key: 574 | key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") 575 | elif '.text_projection' in key: 576 | key = None # 使われない??? 577 | elif '.logit_scale' in key: 578 | key = None # 使われない??? 579 | elif '.token_embedding' in key: 580 | key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") 581 | elif '.ln_final' in key: 582 | key = key.replace(".ln_final", ".final_layer_norm") 583 | return key 584 | 585 | keys = list(checkpoint.keys()) 586 | new_sd = {} 587 | for key in keys: 588 | # remove resblocks 23 589 | if '.resblocks.23.' in key: 590 | continue 591 | new_key = convert_key(key) 592 | if new_key is None: 593 | continue 594 | new_sd[new_key] = checkpoint[key] 595 | 596 | # attnの変換 597 | for key in keys: 598 | if '.resblocks.23.' in key: 599 | continue 600 | if '.resblocks' in key and '.attn.in_proj_' in key: 601 | # 三つに分割 602 | values = torch.chunk(checkpoint[key], 3) 603 | 604 | key_suffix = ".weight" if "weight" in key else ".bias" 605 | key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") 606 | key_pfx = key_pfx.replace("_weight", "") 607 | key_pfx = key_pfx.replace("_bias", "") 608 | key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") 609 | new_sd[key_pfx + "q_proj" + key_suffix] = values[0] 610 | new_sd[key_pfx + "k_proj" + key_suffix] = values[1] 611 | new_sd[key_pfx + "v_proj" + key_suffix] = values[2] 612 | 613 | # position_idsの追加 614 | new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) 615 | return new_sd 616 | 617 | # endregion 618 | 619 | 620 | # region Diffusers->StableDiffusion の変換コード 621 | # convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0) 622 | 623 | def conv_transformer_to_linear(checkpoint): 624 | keys = list(checkpoint.keys()) 625 | tf_keys = ["proj_in.weight", "proj_out.weight"] 626 | for key in keys: 627 | if ".".join(key.split(".")[-2:]) in tf_keys: 628 | if checkpoint[key].ndim > 2: 629 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 630 | 631 | 632 | def convert_unet_state_dict_to_sd(v2, unet_state_dict): 633 | unet_conversion_map = [ 634 | # (stable-diffusion, HF Diffusers) 635 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 636 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 637 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 638 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 639 | ("input_blocks.0.0.weight", "conv_in.weight"), 640 | ("input_blocks.0.0.bias", "conv_in.bias"), 641 | ("out.0.weight", "conv_norm_out.weight"), 642 | ("out.0.bias", "conv_norm_out.bias"), 643 | ("out.2.weight", "conv_out.weight"), 644 | ("out.2.bias", "conv_out.bias"), 645 | ] 646 | 647 | unet_conversion_map_resnet = [ 648 | # (stable-diffusion, HF Diffusers) 649 | ("in_layers.0", "norm1"), 650 | ("in_layers.2", "conv1"), 651 | ("out_layers.0", "norm2"), 652 | ("out_layers.3", "conv2"), 653 | ("emb_layers.1", "time_emb_proj"), 654 | ("skip_connection", "conv_shortcut"), 655 | ] 656 | 657 | unet_conversion_map_layer = [] 658 | for i in range(4): 659 | # loop over downblocks/upblocks 660 | 661 | for j 
in range(2): 662 | # loop over resnets/attentions for downblocks 663 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 664 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 665 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 666 | 667 | if i < 3: 668 | # no attention layers in down_blocks.3 669 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 670 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 671 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 672 | 673 | for j in range(3): 674 | # loop over resnets/attentions for upblocks 675 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 676 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 677 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 678 | 679 | if i > 0: 680 | # no attention layers in up_blocks.0 681 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 682 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 683 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 684 | 685 | if i < 3: 686 | # no downsample in down_blocks.3 687 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 688 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 689 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 690 | 691 | # no upsample in up_blocks.3 692 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 693 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 694 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 695 | 696 | hf_mid_atn_prefix = "mid_block.attentions.0." 697 | sd_mid_atn_prefix = "middle_block.1." 698 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 699 | 700 | for j in range(2): 701 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 702 | sd_mid_res_prefix = f"middle_block.{2*j}." 703 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 704 | 705 | # buyer beware: this is a *brittle* function, 706 | # and correct output requires that all of these pieces interact in 707 | # the exact order in which I have arranged them. 708 | mapping = {k: k for k in unet_state_dict.keys()} 709 | for sd_name, hf_name in unet_conversion_map: 710 | mapping[hf_name] = sd_name 711 | for k, v in mapping.items(): 712 | if "resnets" in k: 713 | for sd_part, hf_part in unet_conversion_map_resnet: 714 | v = v.replace(hf_part, sd_part) 715 | mapping[k] = v 716 | for k, v in mapping.items(): 717 | for sd_part, hf_part in unet_conversion_map_layer: 718 | v = v.replace(hf_part, sd_part) 719 | mapping[k] = v 720 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 721 | 722 | if v2: 723 | conv_transformer_to_linear(new_state_dict) 724 | 725 | return new_state_dict 726 | 727 | 728 | # ================# 729 | # VAE Conversion # 730 | # ================# 731 | 732 | def reshape_weight_for_sd(w): 733 | # convert HF linear weights to SD conv2d weights 734 | return w.reshape(*w.shape, 1, 1) 735 | 736 | 737 | def convert_vae_state_dict(vae_state_dict): 738 | vae_conversion_map = [ 739 | # (stable-diffusion, HF Diffusers) 740 | ("nin_shortcut", "conv_shortcut"), 741 | ("norm_out", "conv_norm_out"), 742 | ("mid.attn_1.", "mid_block.attentions.0."), 743 | ] 744 | 745 | for i in range(4): 746 | # down_blocks have two resnets 747 | for j in range(2): 748 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 749 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 
750 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 751 | 752 | if i < 3: 753 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 754 | sd_downsample_prefix = f"down.{i}.downsample." 755 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 756 | 757 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 758 | sd_upsample_prefix = f"up.{3-i}.upsample." 759 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 760 | 761 | # up_blocks have three resnets 762 | # also, up blocks in hf are numbered in reverse from sd 763 | for j in range(3): 764 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 765 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 766 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 767 | 768 | # this part accounts for mid blocks in both the encoder and the decoder 769 | for i in range(2): 770 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 771 | sd_mid_res_prefix = f"mid.block_{i+1}." 772 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 773 | 774 | vae_conversion_map_attn = [ 775 | # (stable-diffusion, HF Diffusers) 776 | ("norm.", "group_norm."), 777 | ("q.", "query."), 778 | ("k.", "key."), 779 | ("v.", "value."), 780 | ("proj_out.", "proj_attn."), 781 | ] 782 | 783 | mapping = {k: k for k in vae_state_dict.keys()} 784 | for k, v in mapping.items(): 785 | for sd_part, hf_part in vae_conversion_map: 786 | v = v.replace(hf_part, sd_part) 787 | mapping[k] = v 788 | for k, v in mapping.items(): 789 | if "attentions" in k: 790 | for sd_part, hf_part in vae_conversion_map_attn: 791 | v = v.replace(hf_part, sd_part) 792 | mapping[k] = v 793 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 794 | weights_to_convert = ["q", "k", "v", "proj_out"] 795 | 796 | for k, v in new_state_dict.items(): 797 | for weight_name in weights_to_convert: 798 | if f"mid.attn_1.{weight_name}.weight" in k: 799 | new_state_dict[k] = reshape_weight_for_sd(v) 800 | 801 | return new_state_dict 802 | 803 | 804 | # endregion 805 | 806 | 807 | def load_checkpoint_with_text_encoder_conversion(ckpt_path): 808 | # text encoderの格納形式が違うモデルに対応する ('text_model'がない) 809 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 810 | ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), 811 | ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), 812 | ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') 813 | ] 814 | 815 | checkpoint = torch.load(ckpt_path, map_location="cuda") 816 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 817 | key_reps = [] 818 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 819 | for key in state_dict.keys(): 820 | if key.startswith(rep_from): 821 | new_key = rep_to + key[len(rep_from):] 822 | key_reps.append((key, new_key)) 823 | 824 | for key, new_key in key_reps: 825 | state_dict[new_key] = state_dict[key] 826 | del state_dict[key] 827 | 828 | return checkpoint 829 | 830 | 831 | # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 832 | def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): 833 | 834 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 835 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 836 | if dtype is not None: 837 | for k, v in state_dict.items(): 838 | if type(v) is torch.Tensor: 839 | state_dict[k] = v.to(dtype) 840 | 
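# NOTE (added annotation): the load_state_dict() calls below use the default strict=True,
# so any key mismatch in the converted UNet/VAE/text-encoder weights raises immediately;
# the return value kept in `info` would only list missing/unexpected keys if strict=False were passed.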
841 | # Convert the UNet2DConditionModel model. 842 | unet_config = create_unet_diffusers_config(v2) 843 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) 844 | 845 | unet = UNet2DConditionModel(**unet_config) 846 | info = unet.load_state_dict(converted_unet_checkpoint) 847 | 848 | 849 | # Convert the VAE model. 850 | vae_config = create_vae_diffusers_config() 851 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) 852 | 853 | vae = AutoencoderKL(**vae_config) 854 | info = vae.load_state_dict(converted_vae_checkpoint) 855 | 856 | 857 | # convert text_model 858 | if v2: 859 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) 860 | cfg = CLIPTextConfig( 861 | vocab_size=49408, 862 | hidden_size=1024, 863 | intermediate_size=4096, 864 | num_hidden_layers=23, 865 | num_attention_heads=16, 866 | max_position_embeddings=77, 867 | hidden_act="gelu", 868 | layer_norm_eps=1e-05, 869 | dropout=0.0, 870 | attention_dropout=0.0, 871 | initializer_range=0.02, 872 | initializer_factor=1.0, 873 | pad_token_id=1, 874 | bos_token_id=0, 875 | eos_token_id=2, 876 | model_type="clip_text_model", 877 | projection_dim=512, 878 | torch_dtype="float32", 879 | transformers_version="4.25.0.dev0", 880 | ) 881 | text_model = CLIPTextModel._from_config(cfg) 882 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 883 | else: 884 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) 885 | cfg = CLIPTextConfig( 886 | vocab_size=49408, 887 | hidden_size=768, 888 | intermediate_size=3072, 889 | num_hidden_layers=12, 890 | num_attention_heads=12, 891 | max_position_embeddings=77, 892 | hidden_act="quick_gelu", 893 | layer_norm_eps=1e-05, 894 | dropout=0.0, 895 | attention_dropout=0.0, 896 | initializer_range=0.02, 897 | initializer_factor=1.0, 898 | pad_token_id=1, 899 | bos_token_id=0, 900 | eos_token_id=2, 901 | model_type="clip_text_model", 902 | projection_dim=768, 903 | torch_dtype="float32", 904 | transformers_version="4.16.0.dev0", 905 | ) 906 | 907 | 908 | text_model = CLIPTextModel._from_config(cfg) 909 | #text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 910 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 911 | 912 | 913 | return text_model, vae, unet 914 | 915 | 916 | def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): 917 | def convert_key(key): 918 | # position_idsの除去 919 | if ".position_ids" in key: 920 | return None 921 | 922 | # common 923 | key = key.replace("text_model.encoder.", "transformer.") 924 | key = key.replace("text_model.", "") 925 | if "layers" in key: 926 | # resblocks conversion 927 | key = key.replace(".layers.", ".resblocks.") 928 | if ".layer_norm" in key: 929 | key = key.replace(".layer_norm", ".ln_") 930 | elif ".mlp." in key: 931 | key = key.replace(".fc1.", ".c_fc.") 932 | key = key.replace(".fc2.", ".c_proj.") 933 | elif '.self_attn.out_proj' in key: 934 | key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") 935 | elif '.self_attn.' 
in key: 936 | key = None # 特殊なので後で処理する 937 | else: 938 | raise ValueError(f"unexpected key in DiffUsers model: {key}") 939 | elif '.position_embedding' in key: 940 | key = key.replace("embeddings.position_embedding.weight", "positional_embedding") 941 | elif '.token_embedding' in key: 942 | key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") 943 | elif 'final_layer_norm' in key: 944 | key = key.replace("final_layer_norm", "ln_final") 945 | return key 946 | 947 | keys = list(checkpoint.keys()) 948 | new_sd = {} 949 | for key in keys: 950 | new_key = convert_key(key) 951 | if new_key is None: 952 | continue 953 | new_sd[new_key] = checkpoint[key] 954 | 955 | # attnの変換 956 | for key in keys: 957 | if 'layers' in key and 'q_proj' in key: 958 | # 三つを結合 959 | key_q = key 960 | key_k = key.replace("q_proj", "k_proj") 961 | key_v = key.replace("q_proj", "v_proj") 962 | 963 | value_q = checkpoint[key_q] 964 | value_k = checkpoint[key_k] 965 | value_v = checkpoint[key_v] 966 | value = torch.cat([value_q, value_k, value_v]) 967 | 968 | new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") 969 | new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") 970 | new_sd[new_key] = value 971 | 972 | # 最後の層などを捏造するか 973 | if make_dummy_weights: 974 | 975 | keys = list(new_sd.keys()) 976 | for key in keys: 977 | if key.startswith("transformer.resblocks.22."): 978 | new_sd[key.replace(".22.", ".23.")] = new_sd[key] 979 | 980 | # Diffusersに含まれない重みを作っておく 981 | new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) 982 | new_sd['logit_scale'] = torch.tensor(1) 983 | 984 | return new_sd 985 | 986 | 987 | def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): 988 | if ckpt_path is not None: 989 | # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む 990 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 991 | state_dict = checkpoint["state_dict"] 992 | strict = True 993 | else: 994 | # 新しく作る 995 | checkpoint = {} 996 | state_dict = {} 997 | strict = False 998 | 999 | def update_sd(prefix, sd): 1000 | for k, v in sd.items(): 1001 | key = prefix + k 1002 | assert not strict or key in state_dict, f"Illegal key in save SD: {key}" 1003 | if save_dtype is not None: 1004 | v = v.detach().clone().to("cpu").to(save_dtype) 1005 | state_dict[key] = v 1006 | 1007 | # Convert the UNet model 1008 | unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) 1009 | update_sd("model.diffusion_model.", unet_state_dict) 1010 | 1011 | # Convert the text encoder model 1012 | if v2: 1013 | make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる 1014 | text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) 1015 | update_sd("cond_stage_model.model.", text_enc_dict) 1016 | else: 1017 | text_enc_dict = text_encoder.state_dict() 1018 | update_sd("cond_stage_model.transformer.", text_enc_dict) 1019 | 1020 | # Convert the VAE 1021 | if vae is not None: 1022 | vae_dict = convert_vae_state_dict(vae.state_dict()) 1023 | update_sd("first_stage_model.", vae_dict) 1024 | 1025 | # Put together new checkpoint 1026 | key_count = len(state_dict.keys()) 1027 | new_ckpt = {'state_dict': state_dict} 1028 | 1029 | if 'epoch' in checkpoint: 1030 | epochs += checkpoint['epoch'] 1031 | if 'global_step' in checkpoint: 1032 | steps += checkpoint['global_step'] 1033 | 1034 | 
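# NOTE (added annotation): when a reference checkpoint was supplied, its stored 'epoch' and
# 'global_step' have just been added to the values passed in, so the checkpoint written
# below records cumulative totals rather than only this run's counters.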
new_ckpt['epoch'] = epochs 1035 | new_ckpt['global_step'] = steps 1036 | 1037 | torch.save(new_ckpt, output_file) 1038 | 1039 | return key_count 1040 | 1041 | 1042 | def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, vae=None): 1043 | if vae is None: 1044 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 1045 | 1046 | pipeline = StableDiffusionPipeline( 1047 | unet=unet, 1048 | text_encoder=text_encoder, 1049 | vae=vae, 1050 | scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler"), 1051 | tokenizer=CLIPTokenizer.from_pretrained("refmdl", subfolder="tokenizer"), 1052 | ) 1053 | pipeline.save_pretrained(output_dir) 1054 | 1055 | 1056 | 1057 | def convert(args): 1058 | print("Converting to Diffusers ...") 1059 | load_dtype = torch.float16 if args.fp16 else None 1060 | 1061 | save_dtype = None 1062 | if args.fp16: 1063 | save_dtype = torch.float16 1064 | elif args.bf16: 1065 | save_dtype = torch.bfloat16 1066 | elif args.float: 1067 | save_dtype = torch.float 1068 | 1069 | is_load_ckpt = os.path.isfile(args.model_to_load) 1070 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 1071 | 1072 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint" 1073 | assert is_save_ckpt is not None, f"reference model is required to save as Diffusers" 1074 | 1075 | # モデルを読み込む 1076 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 1077 | 1078 | 1079 | if is_load_ckpt: 1080 | v2_model = args.v2 1081 | text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) 1082 | else: 1083 | pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) 1084 | text_encoder = pipe.text_encoder 1085 | vae = pipe.vae 1086 | unet = pipe.unet 1087 | 1088 | if args.v1 == args.v2: 1089 | # 自動判定する 1090 | v2_model = unet.config.cross_attention_dim == 1024 1091 | #print("checking model version: model is " + ('v2' if v2_model else 'v1')) 1092 | else: 1093 | v2_model = args.v1 1094 | 1095 | # 変換して保存する 1096 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 1097 | 1098 | 1099 | if is_save_ckpt: 1100 | original_model = args.model_to_load if is_load_ckpt else None 1101 | key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, 1102 | original_model, args.epoch, args.global_step, save_dtype, vae) 1103 | 1104 | else: 1105 | save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, vae) 1106 | 1107 | 1108 | 1109 | if __name__ == '__main__': 1110 | parser = argparse.ArgumentParser() 1111 | parser.add_argument("--v1", action='store_true', 1112 | help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') 1113 | parser.add_argument("--v2", action='store_true', 1114 | help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') 1115 | parser.add_argument("--fp16", action='store_true', 1116 | help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') 1117 | parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') 1118 | parser.add_argument("--float", action='store_true', 1119 | help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') 1120 | 
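# NOTE (added annotation): example invocations, assuming this script is run directly
# (script name and paths are illustrative placeholders; the remaining flags are defined just below):
#   python convertodiffv2.py --v2 --fp16 /path/to/model.ckpt /path/to/diffusers_dir    # ckpt -> Diffusers
#   python convertodiffv2.py --fp16 /path/to/diffusers_dir /path/to/model_fp16.ckpt    # Diffusers -> ckpt
# A file extension on model_to_save selects checkpoint output; no extension selects a Diffusers directory.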
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') 1121 | parser.add_argument("--global_step", type=int, default=0, 1122 | help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') 1123 | 1124 | parser.add_argument("model_to_load", type=str, default=None, 1125 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") 1126 | parser.add_argument("model_to_save", type=str, default=None, 1127 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") 1128 | 1129 | args = parser.parse_args() 1130 | convert(args) 1131 | -------------------------------------------------------------------------------- /Dreambooth/convertosd.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | # Written by jachiam 5 | 6 | import argparse 7 | import os.path as osp 8 | 9 | import torch 10 | 11 | 12 | # =================# 13 | # UNet Conversion # 14 | # =================# 15 | 16 | unet_conversion_map = [ 17 | # (stable-diffusion, HF Diffusers) 18 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 19 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 20 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 21 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 22 | ("input_blocks.0.0.weight", "conv_in.weight"), 23 | ("input_blocks.0.0.bias", "conv_in.bias"), 24 | ("out.0.weight", "conv_norm_out.weight"), 25 | ("out.0.bias", "conv_norm_out.bias"), 26 | ("out.2.weight", "conv_out.weight"), 27 | ("out.2.bias", "conv_out.bias"), 28 | ] 29 | 30 | unet_conversion_map_resnet = [ 31 | # (stable-diffusion, HF Diffusers) 32 | ("in_layers.0", "norm1"), 33 | ("in_layers.2", "conv1"), 34 | ("out_layers.0", "norm2"), 35 | ("out_layers.3", "conv2"), 36 | ("emb_layers.1", "time_emb_proj"), 37 | ("skip_connection", "conv_shortcut"), 38 | ] 39 | 40 | unet_conversion_map_layer = [] 41 | # hardcoded number of downblocks and resnets/attentions... 42 | # would need smarter logic for other networks. 43 | for i in range(4): 44 | # loop over downblocks/upblocks 45 | 46 | for j in range(2): 47 | # loop over resnets/attentions for downblocks 48 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 49 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 50 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 51 | 52 | if i < 3: 53 | # no attention layers in down_blocks.3 54 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 55 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 56 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 57 | 58 | for j in range(3): 59 | # loop over resnets/attentions for upblocks 60 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 61 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 62 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 63 | 64 | if i > 0: 65 | # no attention layers in up_blocks.0 66 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 67 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 
68 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 69 | 70 | if i < 3: 71 | # no downsample in down_blocks.3 72 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 73 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 74 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 75 | 76 | # no upsample in up_blocks.3 77 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 78 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 79 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 80 | 81 | hf_mid_atn_prefix = "mid_block.attentions.0." 82 | sd_mid_atn_prefix = "middle_block.1." 83 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 84 | 85 | for j in range(2): 86 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 87 | sd_mid_res_prefix = f"middle_block.{2*j}." 88 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 89 | 90 | 91 | def convert_unet_state_dict(unet_state_dict): 92 | # buyer beware: this is a *brittle* function, 93 | # and correct output requires that all of these pieces interact in 94 | # the exact order in which I have arranged them. 95 | mapping = {k: k for k in unet_state_dict.keys()} 96 | for sd_name, hf_name in unet_conversion_map: 97 | mapping[hf_name] = sd_name 98 | for k, v in mapping.items(): 99 | if "resnets" in k: 100 | for sd_part, hf_part in unet_conversion_map_resnet: 101 | v = v.replace(hf_part, sd_part) 102 | mapping[k] = v 103 | for k, v in mapping.items(): 104 | for sd_part, hf_part in unet_conversion_map_layer: 105 | v = v.replace(hf_part, sd_part) 106 | mapping[k] = v 107 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 108 | return new_state_dict 109 | 110 | 111 | # ================# 112 | # VAE Conversion # 113 | # ================# 114 | 115 | vae_conversion_map = [ 116 | # (stable-diffusion, HF Diffusers) 117 | ("nin_shortcut", "conv_shortcut"), 118 | ("norm_out", "conv_norm_out"), 119 | ("mid.attn_1.", "mid_block.attentions.0."), 120 | ] 121 | 122 | for i in range(4): 123 | # down_blocks have two resnets 124 | for j in range(2): 125 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 126 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 127 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 128 | 129 | if i < 3: 130 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 131 | sd_downsample_prefix = f"down.{i}.downsample." 132 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 133 | 134 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 135 | sd_upsample_prefix = f"up.{3-i}.upsample." 136 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 137 | 138 | # up_blocks have three resnets 139 | # also, up blocks in hf are numbered in reverse from sd 140 | for j in range(3): 141 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 142 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 143 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 144 | 145 | # this part accounts for mid blocks in both the encoder and the decoder 146 | for i in range(2): 147 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 148 | sd_mid_res_prefix = f"mid.block_{i+1}." 
149 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 150 | 151 | 152 | vae_conversion_map_attn = [ 153 | # (stable-diffusion, HF Diffusers) 154 | ("norm.", "group_norm."), 155 | ("q.", "query."), 156 | ("k.", "key."), 157 | ("v.", "value."), 158 | ("proj_out.", "proj_attn."), 159 | ] 160 | 161 | 162 | def reshape_weight_for_sd(w): 163 | # convert HF linear weights to SD conv2d weights 164 | return w.reshape(*w.shape, 1, 1) 165 | 166 | 167 | def convert_vae_state_dict(vae_state_dict): 168 | mapping = {k: k for k in vae_state_dict.keys()} 169 | for k, v in mapping.items(): 170 | for sd_part, hf_part in vae_conversion_map: 171 | v = v.replace(hf_part, sd_part) 172 | mapping[k] = v 173 | for k, v in mapping.items(): 174 | if "attentions" in k: 175 | for sd_part, hf_part in vae_conversion_map_attn: 176 | v = v.replace(hf_part, sd_part) 177 | mapping[k] = v 178 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 179 | weights_to_convert = ["q", "k", "v", "proj_out"] 180 | print("Converting to CKPT ...") 181 | for k, v in new_state_dict.items(): 182 | for weight_name in weights_to_convert: 183 | if f"mid.attn_1.{weight_name}.weight" in k: 184 | new_state_dict[k] = reshape_weight_for_sd(v) 185 | return new_state_dict 186 | 187 | 188 | # =========================# 189 | # Text Encoder Conversion # 190 | # =========================# 191 | # pretty much a no-op 192 | 193 | 194 | def convert_text_enc_state_dict(text_enc_dict): 195 | return text_enc_dict 196 | 197 | 198 | if __name__ == "__main__": 199 | 200 | 201 | model_path = "" 202 | checkpoint_path= "" 203 | 204 | unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") 205 | vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") 206 | text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") 207 | 208 | # Convert the UNet model 209 | unet_state_dict = torch.load(unet_path, map_location='cpu') 210 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 211 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 212 | 213 | # Convert the VAE model 214 | vae_state_dict = torch.load(vae_path, map_location='cpu') 215 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 216 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 217 | 218 | # Convert the text encoder model 219 | text_enc_dict = torch.load(text_enc_path, map_location='cpu') 220 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 221 | text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 222 | 223 | # Put together new checkpoint 224 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 225 | 226 | state_dict = {k:v.half() for k,v in state_dict.items()} 227 | state_dict = {"state_dict": state_dict} 228 | torch.save(state_dict, checkpoint_path) 229 | -------------------------------------------------------------------------------- /Dreambooth/convertosdv2.py: -------------------------------------------------------------------------------- 1 | print("Converting to CKPT ...") 2 | import argparse 3 | import os 4 | import torch 5 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig 6 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 7 | 8 | 9 | # DiffUsers版StableDiffusionのモデルパラメータ 10 | NUM_TRAIN_TIMESTEPS = 1000 11 | BETA_START = 0.00085 12 | BETA_END = 0.0120 13 | 14 | UNET_PARAMS_MODEL_CHANNELS = 320 15 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] 16 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] 17 | UNET_PARAMS_IMAGE_SIZE = 64 18 | UNET_PARAMS_IN_CHANNELS = 4 19 | UNET_PARAMS_OUT_CHANNELS = 4 20 | UNET_PARAMS_NUM_RES_BLOCKS = 2 21 | UNET_PARAMS_CONTEXT_DIM = 768 22 | UNET_PARAMS_NUM_HEADS = 8 23 | 24 | VAE_PARAMS_Z_CHANNELS = 4 25 | VAE_PARAMS_RESOLUTION = 768 26 | VAE_PARAMS_IN_CHANNELS = 3 27 | VAE_PARAMS_OUT_CH = 3 28 | VAE_PARAMS_CH = 128 29 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4] 30 | VAE_PARAMS_NUM_RES_BLOCKS = 2 31 | 32 | # V2 33 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] 34 | V2_UNET_PARAMS_CONTEXT_DIM = 1024 35 | 36 | 37 | # region StableDiffusion->Diffusersの変換コード 38 | # convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0) 39 | 40 | 41 | def shave_segments(path, n_shave_prefix_segments=1): 42 | """ 43 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
44 | """ 45 | if n_shave_prefix_segments >= 0: 46 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 47 | else: 48 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 49 | 50 | 51 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 52 | """ 53 | Updates paths inside resnets to the new naming scheme (local renaming) 54 | """ 55 | mapping = [] 56 | for old_item in old_list: 57 | new_item = old_item.replace("in_layers.0", "norm1") 58 | new_item = new_item.replace("in_layers.2", "conv1") 59 | 60 | new_item = new_item.replace("out_layers.0", "norm2") 61 | new_item = new_item.replace("out_layers.3", "conv2") 62 | 63 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 64 | new_item = new_item.replace("skip_connection", "conv_shortcut") 65 | 66 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 67 | 68 | mapping.append({"old": old_item, "new": new_item}) 69 | 70 | return mapping 71 | 72 | 73 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 74 | """ 75 | Updates paths inside resnets to the new naming scheme (local renaming) 76 | """ 77 | mapping = [] 78 | for old_item in old_list: 79 | new_item = old_item 80 | 81 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 82 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 83 | 84 | mapping.append({"old": old_item, "new": new_item}) 85 | 86 | return mapping 87 | 88 | 89 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 90 | """ 91 | Updates paths inside attentions to the new naming scheme (local renaming) 92 | """ 93 | mapping = [] 94 | for old_item in old_list: 95 | new_item = old_item 96 | 97 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 98 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 99 | 100 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 101 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 102 | 103 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 104 | 105 | mapping.append({"old": old_item, "new": new_item}) 106 | 107 | return mapping 108 | 109 | 110 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 111 | """ 112 | Updates paths inside attentions to the new naming scheme (local renaming) 113 | """ 114 | mapping = [] 115 | for old_item in old_list: 116 | new_item = old_item 117 | 118 | new_item = new_item.replace("norm.weight", "group_norm.weight") 119 | new_item = new_item.replace("norm.bias", "group_norm.bias") 120 | 121 | new_item = new_item.replace("q.weight", "query.weight") 122 | new_item = new_item.replace("q.bias", "query.bias") 123 | 124 | new_item = new_item.replace("k.weight", "key.weight") 125 | new_item = new_item.replace("k.bias", "key.bias") 126 | 127 | new_item = new_item.replace("v.weight", "value.weight") 128 | new_item = new_item.replace("v.bias", "value.bias") 129 | 130 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 131 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 132 | 133 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 134 | 135 | mapping.append({"old": old_item, "new": new_item}) 136 | 137 | return mapping 138 | 139 | 140 | def assign_to_checkpoint( 141 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 142 | ): 143 | """ 144 | This does the final conversion step: take locally converted weights 
and apply a global renaming 145 | to them. It splits attention layers, and takes into account additional replacements 146 | that may arise. 147 | 148 | Assigns the weights to the new checkpoint. 149 | """ 150 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 151 | 152 | # Splits the attention layers into three variables. 153 | if attention_paths_to_split is not None: 154 | for path, path_map in attention_paths_to_split.items(): 155 | old_tensor = old_checkpoint[path] 156 | channels = old_tensor.shape[0] // 3 157 | 158 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 159 | 160 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 161 | 162 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 163 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 164 | 165 | checkpoint[path_map["query"]] = query.reshape(target_shape) 166 | checkpoint[path_map["key"]] = key.reshape(target_shape) 167 | checkpoint[path_map["value"]] = value.reshape(target_shape) 168 | 169 | for path in paths: 170 | new_path = path["new"] 171 | 172 | # These have already been assigned 173 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 174 | continue 175 | 176 | # Global renaming happens here 177 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 178 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 179 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 180 | 181 | if additional_replacements is not None: 182 | for replacement in additional_replacements: 183 | new_path = new_path.replace(replacement["old"], replacement["new"]) 184 | 185 | # proj_attn.weight has to be converted from conv 1D to linear 186 | if "proj_attn.weight" in new_path: 187 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 188 | else: 189 | checkpoint[new_path] = old_checkpoint[path["old"]] 190 | 191 | 192 | def conv_attn_to_linear(checkpoint): 193 | keys = list(checkpoint.keys()) 194 | attn_keys = ["query.weight", "key.weight", "value.weight"] 195 | for key in keys: 196 | if ".".join(key.split(".")[-2:]) in attn_keys: 197 | if checkpoint[key].ndim > 2: 198 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 199 | elif "proj_attn.weight" in key: 200 | if checkpoint[key].ndim > 2: 201 | checkpoint[key] = checkpoint[key][:, :, 0] 202 | 203 | 204 | def linear_transformer_to_conv(checkpoint): 205 | keys = list(checkpoint.keys()) 206 | tf_keys = ["proj_in.weight", "proj_out.weight"] 207 | for key in keys: 208 | if ".".join(key.split(".")[-2:]) in tf_keys: 209 | if checkpoint[key].ndim == 2: 210 | checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) 211 | 212 | 213 | def convert_ldm_unet_checkpoint(v2, checkpoint, config): 214 | """ 215 | Takes a state dict and a config, and returns a converted checkpoint. 216 | """ 217 | 218 | # extract state_dict for UNet 219 | unet_state_dict = {} 220 | unet_key = "model.diffusion_model." 
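# NOTE (added annotation): the loop below extracts the UNet sub-dict by stripping the LDM
# prefix, e.g. "model.diffusion_model.input_blocks.0.0.weight" -> "input_blocks.0.0.weight";
# pop() removes each entry from the full checkpoint as it is copied.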
221 | keys = list(checkpoint.keys()) 222 | for key in keys: 223 | if key.startswith(unet_key): 224 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 225 | 226 | new_checkpoint = {} 227 | 228 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 229 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 230 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 231 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 232 | 233 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 234 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 235 | 236 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 237 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 238 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 239 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 240 | 241 | # Retrieves the keys for the input blocks only 242 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 243 | input_blocks = { 244 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 245 | for layer_id in range(num_input_blocks) 246 | } 247 | 248 | # Retrieves the keys for the middle blocks only 249 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 250 | middle_blocks = { 251 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 252 | for layer_id in range(num_middle_blocks) 253 | } 254 | 255 | # Retrieves the keys for the output blocks only 256 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 257 | output_blocks = { 258 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 259 | for layer_id in range(num_output_blocks) 260 | } 261 | 262 | for i in range(1, num_input_blocks): 263 | block_id = (i - 1) // (config["layers_per_block"] + 1) 264 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 265 | 266 | resnets = [ 267 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 268 | ] 269 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 270 | 271 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 272 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 273 | f"input_blocks.{i}.0.op.weight" 274 | ) 275 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 276 | f"input_blocks.{i}.0.op.bias" 277 | ) 278 | 279 | paths = renew_resnet_paths(resnets) 280 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 281 | assign_to_checkpoint( 282 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 283 | ) 284 | 285 | if len(attentions): 286 | paths = renew_attention_paths(attentions) 287 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 288 | assign_to_checkpoint( 289 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 290 | ) 291 | 292 | resnet_0 = middle_blocks[0] 293 | attentions = middle_blocks[1] 294 | resnet_1 = 
middle_blocks[2] 295 | 296 | resnet_0_paths = renew_resnet_paths(resnet_0) 297 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 298 | 299 | resnet_1_paths = renew_resnet_paths(resnet_1) 300 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 301 | 302 | attentions_paths = renew_attention_paths(attentions) 303 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 304 | assign_to_checkpoint( 305 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 306 | ) 307 | 308 | for i in range(num_output_blocks): 309 | block_id = i // (config["layers_per_block"] + 1) 310 | layer_in_block_id = i % (config["layers_per_block"] + 1) 311 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 312 | output_block_list = {} 313 | 314 | for layer in output_block_layers: 315 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 316 | if layer_id in output_block_list: 317 | output_block_list[layer_id].append(layer_name) 318 | else: 319 | output_block_list[layer_id] = [layer_name] 320 | 321 | if len(output_block_list) > 1: 322 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 323 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 324 | 325 | resnet_0_paths = renew_resnet_paths(resnets) 326 | paths = renew_resnet_paths(resnets) 327 | 328 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 329 | assign_to_checkpoint( 330 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 331 | ) 332 | 333 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 334 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 335 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 336 | f"output_blocks.{i}.{index}.conv.weight" 337 | ] 338 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 339 | f"output_blocks.{i}.{index}.conv.bias" 340 | ] 341 | 342 | # Clear attentions as they have been attributed above. 343 | if len(attentions) == 2: 344 | attentions = [] 345 | 346 | if len(attentions): 347 | paths = renew_attention_paths(attentions) 348 | meta_path = { 349 | "old": f"output_blocks.{i}.1", 350 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 351 | } 352 | assign_to_checkpoint( 353 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 354 | ) 355 | else: 356 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 357 | for path in resnet_0_paths: 358 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 359 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 360 | 361 | new_checkpoint[new_path] = unet_state_dict[old_path] 362 | 363 | # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する 364 | if v2: 365 | linear_transformer_to_conv(new_checkpoint) 366 | 367 | return new_checkpoint 368 | 369 | 370 | def convert_ldm_vae_checkpoint(checkpoint, config): 371 | # extract state dict for VAE 372 | vae_state_dict = {} 373 | vae_key = "first_stage_model." 
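# NOTE (added annotation): same prefix-stripping idea as for the UNet above, but here
# checkpoint.get() is used instead of checkpoint.pop(), so the original "first_stage_model.*"
# entries remain in the source checkpoint.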
374 | keys = list(checkpoint.keys()) 375 | for key in keys: 376 | if key.startswith(vae_key): 377 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 378 | # if len(vae_state_dict) == 0: 379 | # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict 380 | # vae_state_dict = checkpoint 381 | 382 | new_checkpoint = {} 383 | 384 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 385 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 386 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 387 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 388 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 389 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 390 | 391 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 392 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 393 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 394 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 395 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 396 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 397 | 398 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 399 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 400 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 401 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 402 | 403 | # Retrieves the keys for the encoder down blocks only 404 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 405 | down_blocks = { 406 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 407 | } 408 | 409 | # Retrieves the keys for the decoder up blocks only 410 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 411 | up_blocks = { 412 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 413 | } 414 | 415 | for i in range(num_down_blocks): 416 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 417 | 418 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 419 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 420 | f"encoder.down.{i}.downsample.conv.weight" 421 | ) 422 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 423 | f"encoder.down.{i}.downsample.conv.bias" 424 | ) 425 | 426 | paths = renew_vae_resnet_paths(resnets) 427 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 428 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 429 | 430 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 431 | num_mid_res_blocks = 2 432 | for i in range(1, num_mid_res_blocks + 1): 433 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 434 | 435 | paths = renew_vae_resnet_paths(resnets) 436 | meta_path = {"old": f"mid.block_{i}", 
"new": f"mid_block.resnets.{i - 1}"} 437 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 438 | 439 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 440 | paths = renew_vae_attention_paths(mid_attentions) 441 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 442 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 443 | conv_attn_to_linear(new_checkpoint) 444 | 445 | for i in range(num_up_blocks): 446 | block_id = num_up_blocks - 1 - i 447 | resnets = [ 448 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 449 | ] 450 | 451 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 452 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 453 | f"decoder.up.{block_id}.upsample.conv.weight" 454 | ] 455 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 456 | f"decoder.up.{block_id}.upsample.conv.bias" 457 | ] 458 | 459 | paths = renew_vae_resnet_paths(resnets) 460 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 461 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 462 | 463 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 464 | num_mid_res_blocks = 2 465 | for i in range(1, num_mid_res_blocks + 1): 466 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 467 | 468 | paths = renew_vae_resnet_paths(resnets) 469 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 470 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 471 | 472 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 473 | paths = renew_vae_attention_paths(mid_attentions) 474 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 475 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 476 | conv_attn_to_linear(new_checkpoint) 477 | return new_checkpoint 478 | 479 | 480 | def create_unet_diffusers_config(v2): 481 | """ 482 | Creates a config for the diffusers based on the config of the LDM model. 
483 | """ 484 | # unet_params = original_config.model.params.unet_config.params 485 | 486 | block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] 487 | 488 | down_block_types = [] 489 | resolution = 1 490 | for i in range(len(block_out_channels)): 491 | block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" 492 | down_block_types.append(block_type) 493 | if i != len(block_out_channels) - 1: 494 | resolution *= 2 495 | 496 | up_block_types = [] 497 | for i in range(len(block_out_channels)): 498 | block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" 499 | up_block_types.append(block_type) 500 | resolution //= 2 501 | 502 | config = dict( 503 | sample_size=UNET_PARAMS_IMAGE_SIZE, 504 | in_channels=UNET_PARAMS_IN_CHANNELS, 505 | out_channels=UNET_PARAMS_OUT_CHANNELS, 506 | down_block_types=tuple(down_block_types), 507 | up_block_types=tuple(up_block_types), 508 | block_out_channels=tuple(block_out_channels), 509 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, 510 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, 511 | attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, 512 | ) 513 | 514 | return config 515 | 516 | 517 | def create_vae_diffusers_config(): 518 | """ 519 | Creates a config for the diffusers based on the config of the LDM model. 520 | """ 521 | # vae_params = original_config.model.params.first_stage_config.params.ddconfig 522 | # _ = original_config.model.params.first_stage_config.params.embed_dim 523 | block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] 524 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 525 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 526 | 527 | config = dict( 528 | sample_size=VAE_PARAMS_RESOLUTION, 529 | in_channels=VAE_PARAMS_IN_CHANNELS, 530 | out_channels=VAE_PARAMS_OUT_CH, 531 | down_block_types=tuple(down_block_types), 532 | up_block_types=tuple(up_block_types), 533 | block_out_channels=tuple(block_out_channels), 534 | latent_channels=VAE_PARAMS_Z_CHANNELS, 535 | layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, 536 | ) 537 | return config 538 | 539 | 540 | def convert_ldm_clip_checkpoint_v1(checkpoint): 541 | keys = list(checkpoint.keys()) 542 | text_model_dict = {} 543 | for key in keys: 544 | if key.startswith("cond_stage_model.transformer"): 545 | text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] 546 | return text_model_dict 547 | 548 | 549 | def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): 550 | # 嫌になるくらい違うぞ! 551 | def convert_key(key): 552 | if not key.startswith("cond_stage_model"): 553 | return None 554 | 555 | # common conversion 556 | key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") 557 | key = key.replace("cond_stage_model.model.", "text_model.") 558 | 559 | if "resblocks" in key: 560 | # resblocks conversion 561 | key = key.replace(".resblocks.", ".layers.") 562 | if ".ln_" in key: 563 | key = key.replace(".ln_", ".layer_norm") 564 | elif ".mlp." 
in key: 565 | key = key.replace(".c_fc.", ".fc1.") 566 | key = key.replace(".c_proj.", ".fc2.") 567 | elif '.attn.out_proj' in key: 568 | key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") 569 | elif '.attn.in_proj' in key: 570 | key = None # 特殊なので後で処理する 571 | else: 572 | raise ValueError(f"unexpected key in SD: {key}") 573 | elif '.positional_embedding' in key: 574 | key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") 575 | elif '.text_projection' in key: 576 | key = None # 使われない??? 577 | elif '.logit_scale' in key: 578 | key = None # 使われない??? 579 | elif '.token_embedding' in key: 580 | key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") 581 | elif '.ln_final' in key: 582 | key = key.replace(".ln_final", ".final_layer_norm") 583 | return key 584 | 585 | keys = list(checkpoint.keys()) 586 | new_sd = {} 587 | for key in keys: 588 | # remove resblocks 23 589 | if '.resblocks.23.' in key: 590 | continue 591 | new_key = convert_key(key) 592 | if new_key is None: 593 | continue 594 | new_sd[new_key] = checkpoint[key] 595 | 596 | # attnの変換 597 | for key in keys: 598 | if '.resblocks.23.' in key: 599 | continue 600 | if '.resblocks' in key and '.attn.in_proj_' in key: 601 | # 三つに分割 602 | values = torch.chunk(checkpoint[key], 3) 603 | 604 | key_suffix = ".weight" if "weight" in key else ".bias" 605 | key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") 606 | key_pfx = key_pfx.replace("_weight", "") 607 | key_pfx = key_pfx.replace("_bias", "") 608 | key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") 609 | new_sd[key_pfx + "q_proj" + key_suffix] = values[0] 610 | new_sd[key_pfx + "k_proj" + key_suffix] = values[1] 611 | new_sd[key_pfx + "v_proj" + key_suffix] = values[2] 612 | 613 | # position_idsの追加 614 | new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) 615 | return new_sd 616 | 617 | # endregion 618 | 619 | 620 | # region Diffusers->StableDiffusion の変換コード 621 | # convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0) 622 | 623 | def conv_transformer_to_linear(checkpoint): 624 | keys = list(checkpoint.keys()) 625 | tf_keys = ["proj_in.weight", "proj_out.weight"] 626 | for key in keys: 627 | if ".".join(key.split(".")[-2:]) in tf_keys: 628 | if checkpoint[key].ndim > 2: 629 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 630 | 631 | 632 | def convert_unet_state_dict_to_sd(v2, unet_state_dict): 633 | unet_conversion_map = [ 634 | # (stable-diffusion, HF Diffusers) 635 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 636 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 637 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 638 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 639 | ("input_blocks.0.0.weight", "conv_in.weight"), 640 | ("input_blocks.0.0.bias", "conv_in.bias"), 641 | ("out.0.weight", "conv_norm_out.weight"), 642 | ("out.0.bias", "conv_norm_out.bias"), 643 | ("out.2.weight", "conv_out.weight"), 644 | ("out.2.bias", "conv_out.bias"), 645 | ] 646 | 647 | unet_conversion_map_resnet = [ 648 | # (stable-diffusion, HF Diffusers) 649 | ("in_layers.0", "norm1"), 650 | ("in_layers.2", "conv1"), 651 | ("out_layers.0", "norm2"), 652 | ("out_layers.3", "conv2"), 653 | ("emb_layers.1", "time_emb_proj"), 654 | ("skip_connection", "conv_shortcut"), 655 | ] 656 | 657 | unet_conversion_map_layer = [] 658 | for i in range(4): 659 | # loop over downblocks/upblocks 660 | 661 | for j 
in range(2): 662 | # loop over resnets/attentions for downblocks 663 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 664 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 665 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 666 | 667 | if i < 3: 668 | # no attention layers in down_blocks.3 669 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 670 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 671 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 672 | 673 | for j in range(3): 674 | # loop over resnets/attentions for upblocks 675 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 676 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 677 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 678 | 679 | if i > 0: 680 | # no attention layers in up_blocks.0 681 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 682 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 683 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 684 | 685 | if i < 3: 686 | # no downsample in down_blocks.3 687 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 688 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 689 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 690 | 691 | # no upsample in up_blocks.3 692 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 693 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 694 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 695 | 696 | hf_mid_atn_prefix = "mid_block.attentions.0." 697 | sd_mid_atn_prefix = "middle_block.1." 698 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 699 | 700 | for j in range(2): 701 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 702 | sd_mid_res_prefix = f"middle_block.{2*j}." 703 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 704 | 705 | # buyer beware: this is a *brittle* function, 706 | # and correct output requires that all of these pieces interact in 707 | # the exact order in which I have arranged them. 708 | mapping = {k: k for k in unet_state_dict.keys()} 709 | for sd_name, hf_name in unet_conversion_map: 710 | mapping[hf_name] = sd_name 711 | for k, v in mapping.items(): 712 | if "resnets" in k: 713 | for sd_part, hf_part in unet_conversion_map_resnet: 714 | v = v.replace(hf_part, sd_part) 715 | mapping[k] = v 716 | for k, v in mapping.items(): 717 | for sd_part, hf_part in unet_conversion_map_layer: 718 | v = v.replace(hf_part, sd_part) 719 | mapping[k] = v 720 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 721 | 722 | if v2: 723 | conv_transformer_to_linear(new_state_dict) 724 | 725 | return new_state_dict 726 | 727 | 728 | # ================# 729 | # VAE Conversion # 730 | # ================# 731 | 732 | def reshape_weight_for_sd(w): 733 | # convert HF linear weights to SD conv2d weights 734 | return w.reshape(*w.shape, 1, 1) 735 | 736 | 737 | def convert_vae_state_dict(vae_state_dict): 738 | vae_conversion_map = [ 739 | # (stable-diffusion, HF Diffusers) 740 | ("nin_shortcut", "conv_shortcut"), 741 | ("norm_out", "conv_norm_out"), 742 | ("mid.attn_1.", "mid_block.attentions.0."), 743 | ] 744 | 745 | for i in range(4): 746 | # down_blocks have two resnets 747 | for j in range(2): 748 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 749 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 
750 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 751 | 752 | if i < 3: 753 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 754 | sd_downsample_prefix = f"down.{i}.downsample." 755 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 756 | 757 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 758 | sd_upsample_prefix = f"up.{3-i}.upsample." 759 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 760 | 761 | # up_blocks have three resnets 762 | # also, up blocks in hf are numbered in reverse from sd 763 | for j in range(3): 764 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 765 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 766 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 767 | 768 | # this part accounts for mid blocks in both the encoder and the decoder 769 | for i in range(2): 770 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 771 | sd_mid_res_prefix = f"mid.block_{i+1}." 772 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 773 | 774 | vae_conversion_map_attn = [ 775 | # (stable-diffusion, HF Diffusers) 776 | ("norm.", "group_norm."), 777 | ("q.", "query."), 778 | ("k.", "key."), 779 | ("v.", "value."), 780 | ("proj_out.", "proj_attn."), 781 | ] 782 | 783 | mapping = {k: k for k in vae_state_dict.keys()} 784 | for k, v in mapping.items(): 785 | for sd_part, hf_part in vae_conversion_map: 786 | v = v.replace(hf_part, sd_part) 787 | mapping[k] = v 788 | for k, v in mapping.items(): 789 | if "attentions" in k: 790 | for sd_part, hf_part in vae_conversion_map_attn: 791 | v = v.replace(hf_part, sd_part) 792 | mapping[k] = v 793 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 794 | weights_to_convert = ["q", "k", "v", "proj_out"] 795 | 796 | for k, v in new_state_dict.items(): 797 | for weight_name in weights_to_convert: 798 | if f"mid.attn_1.{weight_name}.weight" in k: 799 | new_state_dict[k] = reshape_weight_for_sd(v) 800 | 801 | return new_state_dict 802 | 803 | 804 | # endregion 805 | 806 | 807 | def load_checkpoint_with_text_encoder_conversion(ckpt_path): 808 | # text encoderの格納形式が違うモデルに対応する ('text_model'がない) 809 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 810 | ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), 811 | ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), 812 | ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') 813 | ] 814 | 815 | checkpoint = torch.load(ckpt_path, map_location="cpu") 816 | state_dict = checkpoint["state_dict"] 817 | key_reps = [] 818 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 819 | for key in state_dict.keys(): 820 | if key.startswith(rep_from): 821 | new_key = rep_to + key[len(rep_from):] 822 | key_reps.append((key, new_key)) 823 | 824 | for key, new_key in key_reps: 825 | state_dict[new_key] = state_dict[key] 826 | del state_dict[key] 827 | 828 | return checkpoint 829 | 830 | 831 | # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 832 | def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): 833 | 834 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 835 | state_dict = checkpoint["state_dict"] 836 | if dtype is not None: 837 | for k, v in state_dict.items(): 838 | if type(v) is torch.Tensor: 839 | state_dict[k] = v.to(dtype) 840 | 841 | # Convert the UNet2DConditionModel model. 
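# create_unet_diffusers_config(v2) builds a Diffusers-style UNet config for the selected
# SD version, and convert_ldm_unet_checkpoint remaps the LDM UNet keys into that layout
# before they are loaded into UNet2DConditionModel below (both helpers are presumably
# defined earlier in this script).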
842 | unet_config = create_unet_diffusers_config(v2) 843 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) 844 | 845 | unet = UNet2DConditionModel(**unet_config) 846 | info = unet.load_state_dict(converted_unet_checkpoint) 847 | 848 | 849 | # Convert the VAE model. 850 | vae_config = create_vae_diffusers_config() 851 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) 852 | 853 | vae = AutoencoderKL(**vae_config) 854 | info = vae.load_state_dict(converted_vae_checkpoint) 855 | 856 | 857 | # convert text_model 858 | if v2: 859 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) 860 | cfg = CLIPTextConfig( 861 | vocab_size=49408, 862 | hidden_size=1024, 863 | intermediate_size=4096, 864 | num_hidden_layers=23, 865 | num_attention_heads=16, 866 | max_position_embeddings=77, 867 | hidden_act="gelu", 868 | layer_norm_eps=1e-05, 869 | dropout=0.0, 870 | attention_dropout=0.0, 871 | initializer_range=0.02, 872 | initializer_factor=1.0, 873 | pad_token_id=1, 874 | bos_token_id=0, 875 | eos_token_id=2, 876 | model_type="clip_text_model", 877 | projection_dim=512, 878 | torch_dtype="float32", 879 | transformers_version="4.25.0.dev0", 880 | ) 881 | text_model = CLIPTextModel._from_config(cfg) 882 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 883 | else: 884 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) 885 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 886 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 887 | 888 | 889 | return text_model, vae, unet 890 | 891 | 892 | def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): 893 | def convert_key(key): 894 | # position_idsの除去 895 | if ".position_ids" in key: 896 | return None 897 | 898 | # common 899 | key = key.replace("text_model.encoder.", "transformer.") 900 | key = key.replace("text_model.", "") 901 | if "layers" in key: 902 | # resblocks conversion 903 | key = key.replace(".layers.", ".resblocks.") 904 | if ".layer_norm" in key: 905 | key = key.replace(".layer_norm", ".ln_") 906 | elif ".mlp." in key: 907 | key = key.replace(".fc1.", ".c_fc.") 908 | key = key.replace(".fc2.", ".c_proj.") 909 | elif '.self_attn.out_proj' in key: 910 | key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") 911 | elif '.self_attn.' 
in key: 912 | key = None # 特殊なので後で処理する 913 | else: 914 | raise ValueError(f"unexpected key in DiffUsers model: {key}") 915 | elif '.position_embedding' in key: 916 | key = key.replace("embeddings.position_embedding.weight", "positional_embedding") 917 | elif '.token_embedding' in key: 918 | key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") 919 | elif 'final_layer_norm' in key: 920 | key = key.replace("final_layer_norm", "ln_final") 921 | return key 922 | 923 | keys = list(checkpoint.keys()) 924 | new_sd = {} 925 | for key in keys: 926 | new_key = convert_key(key) 927 | if new_key is None: 928 | continue 929 | new_sd[new_key] = checkpoint[key] 930 | 931 | # attnの変換 932 | for key in keys: 933 | if 'layers' in key and 'q_proj' in key: 934 | # 三つを結合 935 | key_q = key 936 | key_k = key.replace("q_proj", "k_proj") 937 | key_v = key.replace("q_proj", "v_proj") 938 | 939 | value_q = checkpoint[key_q] 940 | value_k = checkpoint[key_k] 941 | value_v = checkpoint[key_v] 942 | value = torch.cat([value_q, value_k, value_v]) 943 | 944 | new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") 945 | new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") 946 | new_sd[new_key] = value 947 | 948 | # 最後の層などを捏造するか 949 | if make_dummy_weights: 950 | 951 | keys = list(new_sd.keys()) 952 | for key in keys: 953 | if key.startswith("transformer.resblocks.22."): 954 | new_sd[key.replace(".22.", ".23.")] = new_sd[key] 955 | 956 | # Diffusersに含まれない重みを作っておく 957 | new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) 958 | new_sd['logit_scale'] = torch.tensor(1) 959 | 960 | return new_sd 961 | 962 | 963 | def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): 964 | if ckpt_path is not None: 965 | # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む 966 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 967 | state_dict = checkpoint["state_dict"] 968 | strict = True 969 | else: 970 | # 新しく作る 971 | checkpoint = {} 972 | state_dict = {} 973 | strict = False 974 | 975 | def update_sd(prefix, sd): 976 | for k, v in sd.items(): 977 | key = prefix + k 978 | assert not strict or key in state_dict, f"Illegal key in save SD: {key}" 979 | if save_dtype is not None: 980 | v = v.detach().clone().to("cpu").to(save_dtype) 981 | state_dict[key] = v 982 | 983 | # Convert the UNet model 984 | unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) 985 | update_sd("model.diffusion_model.", unet_state_dict) 986 | 987 | # Convert the text encoder model 988 | if v2: 989 | make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる 990 | text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) 991 | update_sd("cond_stage_model.model.", text_enc_dict) 992 | else: 993 | text_enc_dict = text_encoder.state_dict() 994 | update_sd("cond_stage_model.transformer.", text_enc_dict) 995 | 996 | # Convert the VAE 997 | if vae is not None: 998 | vae_dict = convert_vae_state_dict(vae.state_dict()) 999 | update_sd("first_stage_model.", vae_dict) 1000 | 1001 | # Put together new checkpoint 1002 | key_count = len(state_dict.keys()) 1003 | new_ckpt = {'state_dict': state_dict} 1004 | 1005 | if 'epoch' in checkpoint: 1006 | epochs += checkpoint['epoch'] 1007 | if 'global_step' in checkpoint: 1008 | steps += checkpoint['global_step'] 1009 | 1010 | new_ckpt['epoch'] = epochs 
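# global_step is carried over in the same way as epoch above, so a checkpoint saved on
# top of an existing one keeps its accumulated training progress.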
1011 | new_ckpt['global_step'] = steps 1012 | 1013 | torch.save(new_ckpt, output_file) 1014 | 1015 | return key_count 1016 | 1017 | 1018 | def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None): 1019 | if vae is None: 1020 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 1021 | pipeline = StableDiffusionPipeline( 1022 | unet=unet, 1023 | text_encoder=text_encoder, 1024 | vae=vae, 1025 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), 1026 | tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), 1027 | ) 1028 | pipeline.save_pretrained(output_dir) 1029 | 1030 | 1031 | 1032 | def convert(args): 1033 | # 引数を確認する 1034 | load_dtype = torch.float16 if args.fp16 else None 1035 | 1036 | save_dtype = None 1037 | if args.fp16: 1038 | save_dtype = torch.float16 1039 | elif args.bf16: 1040 | save_dtype = torch.bfloat16 1041 | elif args.float: 1042 | save_dtype = torch.float 1043 | 1044 | is_load_ckpt = os.path.isfile(args.model_to_load) 1045 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 1046 | 1047 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です" 1048 | assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です" 1049 | 1050 | # モデルを読み込む 1051 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 1052 | 1053 | 1054 | if is_load_ckpt: 1055 | v2_model = args.v2 1056 | text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) 1057 | else: 1058 | pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) 1059 | text_encoder = pipe.text_encoder 1060 | vae = pipe.vae 1061 | unet = pipe.unet 1062 | 1063 | if args.v1 == args.v2: 1064 | # 自動判定する 1065 | v2_model = unet.config.cross_attention_dim == 1024 1066 | #print("checking model version: model is " + ('v2' if v2_model else 'v1')) 1067 | else: 1068 | v2_model = args.v1 1069 | 1070 | # 変換して保存する 1071 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 1072 | 1073 | 1074 | if is_save_ckpt: 1075 | original_model = args.model_to_load if is_load_ckpt else None 1076 | key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, 1077 | original_model, args.epoch, args.global_step, save_dtype, vae) 1078 | 1079 | else: 1080 | save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae) 1081 | 1082 | 1083 | 1084 | if __name__ == '__main__': 1085 | parser = argparse.ArgumentParser() 1086 | parser.add_argument("--v1", action='store_true', 1087 | help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') 1088 | parser.add_argument("--v2", action='store_true', 1089 | help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') 1090 | parser.add_argument("--fp16", action='store_true', 1091 | help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') 1092 | parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') 1093 | parser.add_argument("--float", action='store_true', 1094 | 
help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') 1095 | parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') 1096 | parser.add_argument("--global_step", type=int, default=0, 1097 | help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') 1098 | parser.add_argument("--reference_model", type=str, default=None, 1099 | help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要") 1100 | 1101 | parser.add_argument("model_to_load", type=str, default=None, 1102 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") 1103 | parser.add_argument("model_to_save", type=str, default=None, 1104 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") 1105 | 1106 | args = parser.parse_args() 1107 | convert(args) 1108 | -------------------------------------------------------------------------------- /Dreambooth/det.py: -------------------------------------------------------------------------------- 1 | #Adapted from A1111 2 | import argparse 3 | import torch 4 | import open_clip 5 | import transformers.utils.hub 6 | from safetensors import safe_open 7 | import os 8 | import sys 9 | import wget 10 | from subprocess import call 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--MODEL_PATH", type=str) 14 | parser.add_argument("--from_safetensors", action='store_true') 15 | args = parser.parse_args() 16 | 17 | wget.download("https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dreambooth/ldm.zip") 18 | call('unzip ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 19 | call('rm ldm.zip', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 20 | 21 | import ldm.modules.diffusionmodules.openaimodel 22 | import ldm.modules.encoders.modules 23 | 24 | class DisableInitialization: 25 | 26 | def __init__(self, disable_clip=True): 27 | self.replaced = [] 28 | self.disable_clip = disable_clip 29 | 30 | def replace(self, obj, field, func): 31 | original = getattr(obj, field, None) 32 | if original is None: 33 | return None 34 | 35 | self.replaced.append((obj, field, original)) 36 | setattr(obj, field, func) 37 | 38 | return original 39 | 40 | def __enter__(self): 41 | def do_nothing(*args, **kwargs): 42 | pass 43 | 44 | def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): 45 | return self.create_model_and_transforms(*args, pretrained=None, **kwargs) 46 | 47 | def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): 48 | res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) 49 | res.name_or_path = pretrained_model_name_or_path 50 | return res 51 | 52 | def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): 53 | args = args[0:3] + ('/', ) + args[4:] 54 | return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) 55 | 56 | def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): 57 | 58 | if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': 59 | return None 60 | 
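# Cache-first lookup: try to resolve the requested file from the local cache only,
# and fall back to a regular (network) fetch if that fails or returns None.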
61 | try: 62 | res = original(url, *args, local_files_only=True, **kwargs) 63 | if res is None: 64 | res = original(url, *args, local_files_only=False, **kwargs) 65 | return res 66 | except Exception as e: 67 | return original(url, *args, local_files_only=False, **kwargs) 68 | 69 | def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): 70 | return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) 71 | 72 | def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): 73 | return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) 74 | 75 | def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): 76 | return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) 77 | 78 | self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) 79 | self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) 80 | self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) 81 | 82 | if self.disable_clip: 83 | self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) 84 | self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) 85 | self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) 86 | self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) 87 | self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) 88 | self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) 89 | 90 | def __exit__(self, exc_type, exc_val, exc_tb): 91 | for obj, field, original in self.replaced: 92 | setattr(obj, field, original) 93 | 94 | self.replaced.clear() 95 | 96 | 97 | def vpar(state_dict): 98 | 99 | device = torch.device("cuda") 100 | 101 | with DisableInitialization(): 102 | unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( 103 | use_checkpoint=True, 104 | use_fp16=False, 105 | image_size=32, 106 | in_channels=4, 107 | out_channels=4, 108 | model_channels=320, 109 | attention_resolutions=[4, 2, 1], 110 | num_res_blocks=2, 111 | channel_mult=[1, 2, 4, 4], 112 | num_head_channels=64, 113 | use_spatial_transformer=True, 114 | use_linear_in_transformer=True, 115 | transformer_depth=1, 116 | context_dim=1024, 117 | legacy=False 118 | ) 119 | unet.eval() 120 | 121 | with torch.no_grad(): 122 | unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} 123 | unet.load_state_dict(unet_sd, strict=True) 124 | unet.to(device=device, dtype=torch.float) 125 | 126 | test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 127 | x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 128 | 129 | out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() 130 | 131 | return out < -1 132 | 133 | 134 | def detect_version(sd): 135 | 136 | sys.stdout = open(os.devnull, 'w') 137 | 138 | sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) 139 | diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) 140 | 141 | if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: 142 | 143 | if vpar(sd): 144 | sys.stdout = sys.__stdout__ 145 | sd2_v=print("V2.1-768px") 146 | return sd2_v 147 | else: 148 | sys.stdout = sys.__stdout__ 149 | sd2=print("V2.1-512px") 150 | return sd2 151 | 152 | else: 153 | sys.stdout = sys.__stdout__ 154 | v1=print("1.5") 155 | return v1 156 | 157 | 158 | if args.from_safetensors: 159 | 160 | checkpoint = {} 161 | with safe_open(args.MODEL_PATH, framework="pt", device="cuda") as f: 162 | for key in f.keys(): 163 | checkpoint[key] = f.get_tensor(key) 164 | state_dict = checkpoint 165 | else: 166 | checkpoint = torch.load(args.MODEL_PATH, map_location="cuda") 167 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 168 | 169 | detect_version(state_dict) 170 | 171 | call('rm -r ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 172 | -------------------------------------------------------------------------------- /Dreambooth/ldm.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/ldm.zip -------------------------------------------------------------------------------- /Dreambooth/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "scheduler": [ 5 | "diffusers", 6 | "PNDMScheduler" 7 | ], 8 | "text_encoder": [ 9 | "transformers", 10 | "CLIPTextModel" 11 | ], 12 | "tokenizer": [ 13 | "transformers", 14 | "CLIPTokenizer" 15 | ], 16 | "unet": [ 17 | "diffusers", 18 | "UNet2DConditionModel" 19 | ], 20 | "vae": [ 21 | "diffusers", 22 | "AutoencoderKL" 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /Dreambooth/refmdlz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/71575ff0676dc11ec3ebd032c7ad4ee6d871beff/Dreambooth/refmdlz -------------------------------------------------------------------------------- /Dreambooth/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null 12 | 13 | } 14 | -------------------------------------------------------------------------------- /Dreambooth/smart_crop.py: 
-------------------------------------------------------------------------------- 1 | #Based on A1111 cropping script 2 | import cv2 3 | import os 4 | from math import log, sqrt 5 | import numpy as np 6 | from PIL import Image, ImageDraw 7 | 8 | GREEN = "#0F0" 9 | BLUE = "#00F" 10 | RED = "#F00" 11 | 12 | 13 | def crop_image(im, size): 14 | 15 | def focal_point(im, settings): 16 | corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] 17 | entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] 18 | face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] 19 | 20 | pois = [] 21 | 22 | weight_pref_total = 0 23 | if len(corner_points) > 0: 24 | weight_pref_total += settings.corner_points_weight 25 | if len(entropy_points) > 0: 26 | weight_pref_total += settings.entropy_points_weight 27 | if len(face_points) > 0: 28 | weight_pref_total += settings.face_points_weight 29 | 30 | corner_centroid = None 31 | if len(corner_points) > 0: 32 | corner_centroid = centroid(corner_points) 33 | corner_centroid.weight = settings.corner_points_weight / weight_pref_total 34 | pois.append(corner_centroid) 35 | 36 | entropy_centroid = None 37 | if len(entropy_points) > 0: 38 | entropy_centroid = centroid(entropy_points) 39 | entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total 40 | pois.append(entropy_centroid) 41 | 42 | face_centroid = None 43 | if len(face_points) > 0: 44 | face_centroid = centroid(face_points) 45 | face_centroid.weight = settings.face_points_weight / weight_pref_total 46 | pois.append(face_centroid) 47 | 48 | average_point = poi_average(pois, settings) 49 | 50 | return average_point 51 | 52 | 53 | def image_face_points(im, settings): 54 | 55 | np_im = np.array(im) 56 | gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) 57 | 58 | tries = [ 59 | [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], 60 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], 61 | [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], 62 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], 63 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], 64 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], 65 | [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], 66 | [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] 67 | ] 68 | for t in tries: 69 | classifier = cv2.CascadeClassifier(t[0]) 70 | minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side 71 | try: 72 | faces = classifier.detectMultiScale(gray, scaleFactor=1.1, 73 | minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) 74 | except: 75 | continue 76 | 77 | if len(faces) > 0: 78 | rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] 79 | return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] 80 | return [] 81 | 82 | 83 | def image_corner_points(im, settings): 84 | grayscale = im.convert("L") 85 | 86 | # naive attempt at preventing focal points from collecting at watermarks near the bottom 87 | gd = ImageDraw.Draw(grayscale) 88 | gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") 89 | 90 | np_im = np.array(grayscale) 91 | 92 | points = cv2.goodFeaturesToTrack( 93 | np_im, 94 | maxCorners=100, 95 | qualityLevel=0.04, 96 | minDistance=min(grayscale.width, 
grayscale.height)*0.06, 97 | useHarrisDetector=False, 98 | ) 99 | 100 | if points is None: 101 | return [] 102 | 103 | focal_points = [] 104 | for point in points: 105 | x, y = point.ravel() 106 | focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) 107 | 108 | return focal_points 109 | 110 | 111 | def image_entropy_points(im, settings): 112 | landscape = im.height < im.width 113 | portrait = im.height > im.width 114 | if landscape: 115 | move_idx = [0, 2] 116 | move_max = im.size[0] 117 | elif portrait: 118 | move_idx = [1, 3] 119 | move_max = im.size[1] 120 | else: 121 | return [] 122 | 123 | e_max = 0 124 | crop_current = [0, 0, settings.crop_width, settings.crop_height] 125 | crop_best = crop_current 126 | while crop_current[move_idx[1]] < move_max: 127 | crop = im.crop(tuple(crop_current)) 128 | e = image_entropy(crop) 129 | 130 | if (e > e_max): 131 | e_max = e 132 | crop_best = list(crop_current) 133 | 134 | crop_current[move_idx[0]] += 4 135 | crop_current[move_idx[1]] += 4 136 | 137 | x_mid = int(crop_best[0] + settings.crop_width/2) 138 | y_mid = int(crop_best[1] + settings.crop_height/2) 139 | 140 | return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] 141 | 142 | 143 | def image_entropy(im): 144 | # greyscale image entropy 145 | # band = np.asarray(im.convert("L")) 146 | band = np.asarray(im.convert("1"), dtype=np.uint8) 147 | hist, _ = np.histogram(band, bins=range(0, 256)) 148 | hist = hist[hist > 0] 149 | return -np.log2(hist / hist.sum()).sum() 150 | 151 | def centroid(pois): 152 | x = [poi.x for poi in pois] 153 | y = [poi.y for poi in pois] 154 | return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) 155 | 156 | 157 | def poi_average(pois, settings): 158 | weight = 0.0 159 | x = 0.0 160 | y = 0.0 161 | for poi in pois: 162 | weight += poi.weight 163 | x += poi.x * poi.weight 164 | y += poi.y * poi.weight 165 | avg_x = round(weight and x / weight) 166 | avg_y = round(weight and y / weight) 167 | 168 | return PointOfInterest(avg_x, avg_y) 169 | 170 | 171 | def is_landscape(w, h): 172 | return w > h 173 | 174 | 175 | def is_portrait(w, h): 176 | return h > w 177 | 178 | 179 | def is_square(w, h): 180 | return w == h 181 | 182 | 183 | class PointOfInterest: 184 | def __init__(self, x, y, weight=1.0, size=10): 185 | self.x = x 186 | self.y = y 187 | self.weight = weight 188 | self.size = size 189 | 190 | def bounding(self, size): 191 | return [ 192 | self.x - size//2, 193 | self.y - size//2, 194 | self.x + size//2, 195 | self.y + size//2 196 | ] 197 | 198 | class Settings: 199 | def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5): 200 | self.crop_width = crop_width 201 | self.crop_height = crop_height 202 | self.corner_points_weight = corner_points_weight 203 | self.entropy_points_weight = entropy_points_weight 204 | self.face_points_weight = face_points_weight 205 | 206 | settings = Settings( 207 | crop_width = size, 208 | crop_height = size, 209 | face_points_weight = 0.9, 210 | entropy_points_weight = 0.15, 211 | corner_points_weight = 0.5, 212 | ) 213 | 214 | scale_by = 1 215 | if is_landscape(im.width, im.height): 216 | scale_by = settings.crop_height / im.height 217 | elif is_portrait(im.width, im.height): 218 | scale_by = settings.crop_width / im.width 219 | elif is_square(im.width, im.height): 220 | if is_square(settings.crop_width, settings.crop_height): 221 | scale_by = settings.crop_width / im.width 222 | elif is_landscape(settings.crop_width, 
settings.crop_height): 223 | scale_by = settings.crop_width / im.width 224 | elif is_portrait(settings.crop_width, settings.crop_height): 225 | scale_by = settings.crop_height / im.height 226 | 227 | im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) 228 | im_debug = im.copy() 229 | 230 | focus = focal_point(im_debug, settings) 231 | 232 | # take the focal point and turn it into crop coordinates that try to center over the focal 233 | # point but then get adjusted back into the frame 234 | y_half = int(settings.crop_height / 2) 235 | x_half = int(settings.crop_width / 2) 236 | 237 | x1 = focus.x - x_half 238 | if x1 < 0: 239 | x1 = 0 240 | elif x1 + settings.crop_width > im.width: 241 | x1 = im.width - settings.crop_width 242 | 243 | y1 = focus.y - y_half 244 | if y1 < 0: 245 | y1 = 0 246 | elif y1 + settings.crop_height > im.height: 247 | y1 = im.height - settings.crop_height 248 | 249 | x2 = x1 + settings.crop_width 250 | y2 = y1 + settings.crop_height 251 | 252 | crop = [x1, y1, x2, y2] 253 | 254 | results = [] 255 | 256 | results.append(im.crop(tuple(crop))) 257 | 258 | return results 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © 2022 Ben 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contact on [X](https://x.com/__TheBen) for highly advanced and optimized flux trainer or SDXL colab trainer (paid) 2 | 3 | # fast-stable-diffusion Notebooks, A1111 + DreamBooth 4 | Soon will stop maintaining paperspace notebooks. 5 | 6 |      Colab-AUTOMATIC1111                            Colab-Dreambooth
7 |                 8 | 9 | 10 | Dreambooth paper : https://dreambooth.github.io/ 11 | 12 | SD implementation by @XavierXiao : https://github.com/XavierXiao/Dreambooth-Stable-Diffusion 13 | -------------------------------------------------------------------------------- /fast_stable_diffusion_AUTOMATIC1111.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "47kV9o1Ni8GH" 7 | }, 8 | "source": [ 9 | "# **Colab Pro notebook from https://github.com/TheLastBen/fast-stable-diffusion**" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "cellView": "form", 17 | "id": "Y9EBc437WDOs" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "#@markdown # Connect Google Drive\n", 22 | "from google.colab import drive\n", 23 | "from IPython.display import clear_output\n", 24 | "import ipywidgets as widgets\n", 25 | "import os\n", 26 | "\n", 27 | "def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n", 28 | "Shared_Drive = \"\" #@param {type:\"string\"}\n", 29 | "#@markdown - Leave empty if you're not using a shared drive\n", 30 | "\n", 31 | "print(\"\u001b[0;33mConnecting...\")\n", 32 | "drive.mount('/content/gdrive')\n", 33 | "\n", 34 | "if Shared_Drive!=\"\" and os.path.exists(\"/content/gdrive/Shareddrives\"):\n", 35 | " mainpth=\"Shareddrives/\"+Shared_Drive\n", 36 | "else:\n", 37 | " mainpth=\"MyDrive\"\n", 38 | "\n", 39 | "clear_output()\n", 40 | "inf('\\u2714 Done','success', '50px')\n", 41 | "\n", 42 | "#@markdown ---" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "cellView": "form", 50 | "id": "CFWtw-6EPrKi" 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "#@markdown # Install/Update AUTOMATIC1111 repo\n", 55 | "from IPython.utils import capture\n", 56 | "from IPython.display import clear_output\n", 57 | "from subprocess import getoutput\n", 58 | "import ipywidgets as widgets\n", 59 | "import sys\n", 60 | "import fileinput\n", 61 | "import os\n", 62 | "import time\n", 63 | "import base64\n", 64 | "import requests\n", 65 | "from urllib.request import urlopen, Request\n", 66 | "from urllib.parse import urlparse, parse_qs, unquote\n", 67 | "from tqdm import tqdm\n", 68 | "import six\n", 69 | "\n", 70 | "\n", 71 | "blsaphemy=base64.b64decode((\"ZWJ1aQ==\").encode('ascii')).decode('ascii')\n", 72 | "\n", 73 | "if not os.path.exists(\"/content/gdrive\"):\n", 74 | " print('\u001b[1;31mGdrive not connected, using temporary colab storage ...')\n", 75 | " time.sleep(4)\n", 76 | " mainpth=\"MyDrive\"\n", 77 | " !mkdir -p /content/gdrive/$mainpth\n", 78 | " Shared_Drive=\"\"\n", 79 | "\n", 80 | "if Shared_Drive!=\"\" and not os.path.exists(\"/content/gdrive/Shareddrives\"):\n", 81 | " print('\u001b[1;31mShared drive not detected, using default MyDrive')\n", 82 | " mainpth=\"MyDrive\"\n", 83 | "\n", 84 | "with capture.capture_output() as cap:\n", 85 | " def inf(msg, style, wdth): inf = widgets.Button(description=msg, disabled=True, button_style=style, layout=widgets.Layout(min_width=wdth));display(inf)\n", 86 | " fgitclone = \"git clone --depth 1\"\n", 87 | " !git clone -q --depth 1 --branch main https://github.com/TheLastBen/diffusers\n", 88 | " %mkdir -p /content/gdrive/$mainpth/sd\n", 89 | " %cd /content/gdrive/$mainpth/sd\n", 90 | " !git clone -q --branch master 
https://github.com/AUTOMATIC1111/stable-diffusion-w$blsaphemy\n", 91 | " !mkdir -p /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/cache/\n", 92 | " os.environ['TRANSFORMERS_CACHE']=f\"/content/gdrive/{mainpth}/sd/stable-diffusion-w\"+blsaphemy+\"/cache\"\n", 93 | " os.environ['TORCH_HOME'] = f\"/content/gdrive/{mainpth}/sd/stable-diffusion-w\"+blsaphemy+\"/cache\"\n", 94 | " !mkdir -p /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/repositories\n", 95 | " !git clone https://github.com/AUTOMATIC1111/stable-diffusion-w$blsaphemy-assets /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/repositories/stable-diffusion-webui-assets\n", 96 | "\n", 97 | "with capture.capture_output() as cap:\n", 98 | " %cd /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/\n", 99 | " !git reset --hard\n", 100 | " !git checkout master\n", 101 | " time.sleep(1)\n", 102 | " !rm webui.sh\n", 103 | " !git pull\n", 104 | "clear_output()\n", 105 | "inf('\\u2714 Done','success', '50px')\n", 106 | "\n", 107 | "#@markdown ---" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "cellView": "form", 115 | "id": "ZGV_5H4xrOSp" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "#@markdown # Requirements\n", 120 | "\n", 121 | "print('\u001b[1;32mInstalling requirements...')\n", 122 | "\n", 123 | "with capture.capture_output() as cap:\n", 124 | " %cd /content/\n", 125 | " !wget -q -i https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/main/Dependencies/A1111.txt\n", 126 | " !dpkg -i *.deb\n", 127 | " if not os.path.exists('/content/gdrive/'+mainpth+'/sd/stablediffusion'):\n", 128 | " !tar -C /content/gdrive/$mainpth --zstd -xf sd_mrep.tar.zst\n", 129 | " !tar -C / --zstd -xf gcolabdeps.tar.zst\n", 130 | " !rm *.deb | rm *.zst | rm *.txt\n", 131 | " if not os.path.exists('gdrive/'+mainpth+'/sd/libtcmalloc/libtcmalloc_minimal.so.4'):\n", 132 | " %env CXXFLAGS=-std=c++14\n", 133 | " !wget -q https://github.com/gperftools/gperftools/releases/download/gperftools-2.5/gperftools-2.5.tar.gz && tar zxf gperftools-2.5.tar.gz && mv gperftools-2.5 gperftools\n", 134 | " !wget -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/Patch\n", 135 | " %cd /content/gperftools\n", 136 | " !patch -p1 < /content/Patch\n", 137 | " !./configure --enable-minimal --enable-libunwind --enable-frame-pointers --enable-dynamic-sized-delete-support --enable-sized-delete --enable-emergency-malloc; make -j4\n", 138 | " !mkdir -p /content/gdrive/$mainpth/sd/libtcmalloc && cp .libs/libtcmalloc*.so* /content/gdrive/$mainpth/sd/libtcmalloc\n", 139 | " %env LD_PRELOAD=/content/gdrive/$mainpth/sd/libtcmalloc/libtcmalloc_minimal.so.4\n", 140 | " %cd /content\n", 141 | " !rm *.tar.gz Patch && rm -r /content/gperftools\n", 142 | " else:\n", 143 | " %env LD_PRELOAD=/content/gdrive/$mainpth/sd/libtcmalloc/libtcmalloc_minimal.so.4\n", 144 | "\n", 145 | " !pip uninstall jax -y\n", 146 | " !pip install wandb==0.15.12 pydantic==1.10.2 numpy==1.24.3 controlnet_aux --no-deps -qq\n", 147 | " !pip install diffusers accelerate -U --no-deps -qq\n", 148 | " !rm -r /usr/local/lib/python3.11/dist-packages/tensorflow*\n", 149 | " os.environ['PYTHONWARNINGS'] = 'ignore'\n", 150 | " !sed -i 's@text = _formatwarnmsg(msg)@text =\\\"\\\"@g' /usr/lib/python3.11/warnings.py\n", 151 | " !sed -i 's@from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F401@@g' /usr/local/lib/python3.11/dist-packages/pytorch_lightning/loggers/__init__.py\n", 152 | " 
!sed -i 's@from .mailbox import ContextCancelledError@@g' /usr/local/lib/python3.11/dist-packages/wandb/sdk/lib/retry.py\n", 153 | " !sed -i 's@raise ContextCancelledError(\"retry timeout\")@print(\"retry timeout\")@g' /usr/local/lib/python3.11/dist-packages/wandb/sdk/lib/retry.py\n", 154 | "\n", 155 | "clear_output()\n", 156 | "inf('\\u2714 Done','success', '50px')\n", 157 | "\n", 158 | "#@markdown ---" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "cellView": "form", 166 | "id": "p4wj_txjP3TC" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "#@markdown # Model Download/Load\n", 171 | "\n", 172 | "import gdown\n", 173 | "from gdown.download import get_url_from_gdrive_confirmation\n", 174 | "import re\n", 175 | "\n", 176 | "Use_Temp_Storage = False #@param {type:\"boolean\"}\n", 177 | "#@markdown - If not, make sure you have enough space on your gdrive\n", 178 | "\n", 179 | "#@markdown ---\n", 180 | "\n", 181 | "Model_Version = \"SDXL\" #@param [\"SDXL\", \"1.5\", \"v1.5 Inpainting\", \"V2.1-768px\"]\n", 182 | "\n", 183 | "#@markdown Or\n", 184 | "PATH_to_MODEL = \"\" #@param {type:\"string\"}\n", 185 | "#@markdown - Insert the full path of your custom model or to a folder containing multiple models\n", 186 | "\n", 187 | "#@markdown Or\n", 188 | "MODEL_LINK = \"\" #@param {type:\"string\"}\n", 189 | "\n", 190 | "\n", 191 | "def getsrc(url):\n", 192 | " parsed_url = urlparse(url)\n", 193 | " if parsed_url.netloc == 'civitai.com':\n", 194 | " src='civitai'\n", 195 | " elif parsed_url.netloc == 'drive.google.com':\n", 196 | " src='gdrive'\n", 197 | " elif parsed_url.netloc == 'huggingface.co':\n", 198 | " src='huggingface'\n", 199 | " else:\n", 200 | " src='others'\n", 201 | " return src\n", 202 | "\n", 203 | "src=getsrc(MODEL_LINK)\n", 204 | "\n", 205 | "def get_name(url, gdrive):\n", 206 | " if not gdrive:\n", 207 | " response = requests.get(url, allow_redirects=False)\n", 208 | " if \"Location\" in response.headers:\n", 209 | " redirected_url = response.headers[\"Location\"]\n", 210 | " quer = parse_qs(urlparse(redirected_url).query)\n", 211 | " if \"response-content-disposition\" in quer:\n", 212 | " disp_val = quer[\"response-content-disposition\"][0].split(\";\")\n", 213 | " for vals in disp_val:\n", 214 | " if vals.strip().startswith(\"filename=\"):\n", 215 | " filenm=unquote(vals.split(\"=\", 1)[1].strip())\n", 216 | " return filenm.replace(\"\\\"\",\"\")\n", 217 | " else:\n", 218 | " headers = {\"User-Agent\": \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.95 Safari/537.36\"}\n", 219 | " lnk=\"https://drive.google.com/uc?id={id}&export=download\".format(id=url[url.find(\"/d/\")+3:url.find(\"/view\")])\n", 220 | " res = requests.session().get(lnk, headers=headers, stream=True, verify=True)\n", 221 | " res = requests.session().get(get_url_from_gdrive_confirmation(res.text), headers=headers, stream=True, verify=True)\n", 222 | " content_disposition = six.moves.urllib_parse.unquote(res.headers[\"Content-Disposition\"])\n", 223 | " filenm = re.search('attachment; filename=\"(.*?)\"', content_disposition).groups()[0]\n", 224 | " return filenm\n", 225 | "\n", 226 | "\n", 227 | "def dwn(url, dst, msg):\n", 228 | " file_size = None\n", 229 | " req = Request(url, headers={\"User-Agent\": \"torch.hub\"})\n", 230 | " u = urlopen(req)\n", 231 | " meta = u.info()\n", 232 | " if hasattr(meta, 'getheaders'):\n", 233 | " content_length = 
meta.getheaders(\"Content-Length\")\n", 234 | " else:\n", 235 | " content_length = meta.get_all(\"Content-Length\")\n", 236 | " if content_length is not None and len(content_length) > 0:\n", 237 | " file_size = int(content_length[0])\n", 238 | "\n", 239 | " with tqdm(total=file_size, disable=False, mininterval=0.5,\n", 240 | " bar_format=msg+' |{bar:20}| {percentage:3.0f}%') as pbar:\n", 241 | " with open(dst, \"wb\") as f:\n", 242 | " while True:\n", 243 | " buffer = u.read(8192)\n", 244 | " if len(buffer) == 0:\n", 245 | " break\n", 246 | " f.write(buffer)\n", 247 | " pbar.update(len(buffer))\n", 248 | " f.close()\n", 249 | "\n", 250 | "\n", 251 | "def sdmdls(ver, Use_Temp_Storage):\n", 252 | "\n", 253 | " if ver=='1.5':\n", 254 | " if Use_Temp_Storage:\n", 255 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 256 | " model='/content/temp_models/v1-5-pruned-emaonly.safetensors'\n", 257 | " else:\n", 258 | " model='/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/models/Stable-diffusion/v1-5-pruned-emaonly.safetensors'\n", 259 | " link='https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'\n", 260 | " elif ver=='V2.1-768px':\n", 261 | " if Use_Temp_Storage:\n", 262 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 263 | " model='/content/temp_models/v2-1_768-ema-pruned.safetensors'\n", 264 | " else:\n", 265 | " model='/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/models/Stable-diffusion/v2-1_768-ema-pruned.safetensors'\n", 266 | " link='https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.safetensors'\n", 267 | " elif ver=='v1.5 Inpainting':\n", 268 | " if Use_Temp_Storage:\n", 269 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 270 | " model='/content/temp_models/sd-v1-5-inpainting.ckpt'\n", 271 | " else:\n", 272 | " model='/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/models/Stable-diffusion/sd-v1-5-inpainting.ckpt'\n", 273 | " link='https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt'\n", 274 | " elif ver=='SDXL':\n", 275 | " if Use_Temp_Storage:\n", 276 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 277 | " model='/content/temp_models/sd_xl_base_1.0.safetensors'\n", 278 | " else:\n", 279 | " model='/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/models/Stable-diffusion/sd_xl_base_1.0.safetensors'\n", 280 | " link='https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'\n", 281 | "\n", 282 | " if not os.path.exists(model):\n", 283 | " !gdown --fuzzy -O $model $link\n", 284 | " if os.path.exists(model):\n", 285 | " clear_output()\n", 286 | " inf('\\u2714 Done','success', '50px')\n", 287 | " else:\n", 288 | " inf('\\u2718 Something went wrong, try again','danger', \"250px\")\n", 289 | " else:\n", 290 | " clear_output()\n", 291 | " inf('\\u2714 Model already exists','primary', '300px')\n", 292 | "\n", 293 | " return model\n", 294 | "\n", 295 | "\n", 296 | "if (PATH_to_MODEL !=''):\n", 297 | " if os.path.exists(str(PATH_to_MODEL)):\n", 298 | " inf('\\u2714 Using the trained model.','success', '200px')\n", 299 | "\n", 300 | " else:\n", 301 | " while not os.path.exists(str(PATH_to_MODEL)):\n", 302 | " inf('\\u2718 Wrong path, use the colab file explorer to copy the path : ','danger', \"400px\")\n", 303 | " PATH_to_MODEL=input()\n", 304 | " if os.path.exists(str(PATH_to_MODEL)):\n", 305 | " inf('\\u2714 Using the 
custom model.','success', '200px')\n", 306 | "\n", 307 | " model=PATH_to_MODEL\n", 308 | "\n", 309 | "elif MODEL_LINK != \"\":\n", 310 | "\n", 311 | " if src=='civitai':\n", 312 | " modelname=get_name(MODEL_LINK, False)\n", 313 | " if Use_Temp_Storage:\n", 314 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 315 | " model=f'/content/temp_models/{modelname}'\n", 316 | " else:\n", 317 | " model=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Stable-diffusion/{modelname}'\n", 318 | " if not os.path.exists(model):\n", 319 | " dwn(MODEL_LINK, model, 'Downloading the custom model')\n", 320 | " clear_output()\n", 321 | " else:\n", 322 | " inf('\\u2714 Model already exists','primary', '300px')\n", 323 | " elif src=='gdrive':\n", 324 | " modelname=get_name(MODEL_LINK, True)\n", 325 | " if Use_Temp_Storage:\n", 326 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 327 | " model=f'/content/temp_models/{modelname}'\n", 328 | " else:\n", 329 | " model=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Stable-diffusion/{modelname}'\n", 330 | " if not os.path.exists(model):\n", 331 | " gdown.download(url=MODEL_LINK, output=model, quiet=False, fuzzy=True)\n", 332 | " clear_output()\n", 333 | " else:\n", 334 | " inf('\\u2714 Model already exists','primary', '300px')\n", 335 | " else:\n", 336 | " modelname=os.path.basename(MODEL_LINK)\n", 337 | " if Use_Temp_Storage:\n", 338 | " os.makedirs('/content/temp_models', exist_ok=True)\n", 339 | " model=f'/content/temp_models/{modelname}'\n", 340 | " else:\n", 341 | " model=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Stable-diffusion/{modelname}'\n", 342 | " if not os.path.exists(model):\n", 343 | " gdown.download(url=MODEL_LINK, output=model, quiet=False, fuzzy=True)\n", 344 | " clear_output()\n", 345 | " else:\n", 346 | " inf('\\u2714 Model already exists','primary', '700px')\n", 347 | "\n", 348 | " if os.path.exists(model) and os.path.getsize(model) > 1810671599:\n", 349 | " inf('\\u2714 Model downloaded, using the custom model.','success', '300px')\n", 350 | " else:\n", 351 | " !rm model\n", 352 | " inf('\\u2718 Wrong link, check that the link is valid','danger', \"300px\")\n", 353 | "\n", 354 | "else:\n", 355 | " model=sdmdls(Model_Version, Use_Temp_Storage)\n", 356 | "\n", 357 | "#@markdown ---" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": { 364 | "cellView": "form", 365 | "id": "Svx6Hx0iUPd1" 366 | }, 367 | "outputs": [], 368 | "source": [ 369 | "#@markdown # Download LoRA\n", 370 | "\n", 371 | "LoRA_LINK = \"\" #@param {type:\"string\"}\n", 372 | "\n", 373 | "if LoRA_LINK == \"\":\n", 374 | " inf('\\u2714 Nothing to do','primary', '200px')\n", 375 | "else:\n", 376 | " os.makedirs('/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/models/Lora', exist_ok=True)\n", 377 | "\n", 378 | " src=getsrc(LoRA_LINK)\n", 379 | "\n", 380 | " if src=='civitai':\n", 381 | " modelname=get_name(LoRA_LINK, False)\n", 382 | " loramodel=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Lora/{modelname}'\n", 383 | " if not os.path.exists(loramodel):\n", 384 | " dwn(LoRA_LINK, loramodel, 'Downloading the LoRA model '+modelname)\n", 385 | " clear_output()\n", 386 | " else:\n", 387 | " inf('\\u2714 Model already exists','primary', '200px')\n", 388 | " elif src=='gdrive':\n", 389 | " modelname=get_name(LoRA_LINK, True)\n", 390 | " 
loramodel=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Lora/{modelname}'\n", 391 | " if not os.path.exists(loramodel):\n", 392 | " gdown.download(url=LoRA_LINK, output=loramodel, quiet=False, fuzzy=True)\n", 393 | " clear_output()\n", 394 | " else:\n", 395 | " inf('\\u2714 Model already exists','primary', '200px')\n", 396 | " else:\n", 397 | " modelname=os.path.basename(LoRA_LINK)\n", 398 | " loramodel=f'/content/gdrive/{mainpth}/sd/stable-diffusion-w{blsaphemy}/models/Lora/{modelname}'\n", 399 | " if not os.path.exists(loramodel):\n", 400 | " gdown.download(url=LoRA_LINK, output=loramodel, quiet=False, fuzzy=True)\n", 401 | " clear_output()\n", 402 | " else:\n", 403 | " inf('\\u2714 Model already exists','primary', '200px')\n", 404 | "\n", 405 | " if os.path.exists(loramodel) :\n", 406 | " inf('\\u2714 LoRA downloaded','success', '200px')\n", 407 | " else:\n", 408 | " inf('\\u2718 Wrong link, check that the link is valid','danger', \"300px\")\n", 409 | "\n", 410 | "#@markdown ---" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "metadata": { 417 | "cellView": "form", 418 | "id": "zC3Rz1b2TBcB" 419 | }, 420 | "outputs": [], 421 | "source": [ 422 | "#@markdown # ControlNet\n", 423 | "from torch.hub import download_url_to_file\n", 424 | "from urllib.parse import urlparse\n", 425 | "import re\n", 426 | "from subprocess import run\n", 427 | "\n", 428 | "XL_Model = \"None\" #@param [ \"None\", \"All\", \"Canny\", \"Depth\", \"Sketch\", \"OpenPose\", \"Recolor\"]\n", 429 | "\n", 430 | "v1_Model = \"None\" #@param [ \"None\", \"All (21GB)\", \"Canny\", \"Depth\", \"Lineart\", \"MLSD\", \"Normal\", \"OpenPose\", \"Scribble\", \"Seg\", \"ip2p\", \"Shuffle\", \"Inpaint\", \"Softedge\", \"Lineart_Anime\", \"Tile\", \"T2iadapter_Models\"]\n", 431 | "\n", 432 | "v2_Model = \"None\" #@param [ \"None\", \"All\", \"Canny\", \"Depth\", \"HED\", \"OpenPose\", \"Scribble\"]\n", 433 | "\n", 434 | "#@markdown - Download/update ControlNet extension and its models\n", 435 | "\n", 436 | "def download(url, model_dir):\n", 437 | "\n", 438 | " filename = os.path.basename(urlparse(url).path)\n", 439 | " pth = os.path.abspath(os.path.join(model_dir, filename))\n", 440 | " if not os.path.exists(pth):\n", 441 | " print('Downloading: '+os.path.basename(url))\n", 442 | " download_url_to_file(url, pth, hash_prefix=None, progress=True)\n", 443 | " else:\n", 444 | " print(f\"\u001b[1;32mThe model {filename} already exists\u001b[0m\")\n", 445 | "\n", 446 | "\n", 447 | "Canny='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/diffusers_xl_canny_mid.safetensors'\n", 448 | "Depth='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/diffusers_xl_depth_mid.safetensors'\n", 449 | "Sketch='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_sketch_256lora.safetensors'\n", 450 | "OpenPose='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/thibaud_xl_openpose_256lora.safetensors'\n", 451 | "Recolor='https://huggingface.co/lllyasviel/sd_control_collection/resolve/main/sai_xl_recolor_128lora.safetensors'\n", 452 | "\n", 453 | "\n", 454 | "with capture.capture_output() as cap:\n", 455 | " %cd /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/extensions\n", 456 | " if not os.path.exists('sd-w'+blsaphemy+'-controlnet'):\n", 457 | " !git clone https://github.com/Mikubill/sd-w$blsaphemy-controlnet.git\n", 458 | " %cd /content\n", 459 | " else:\n", 460 | " %cd sd-w$blsaphemy-controlnet\n", 461 | 
" !git reset --hard\n", 462 | " !git pull\n", 463 | " %cd /content\n", 464 | "\n", 465 | "mdldir='/content/gdrive/'+mainpth+'/sd/stable-diffusion-w'+blsaphemy+'/extensions/sd-w'+blsaphemy+'-controlnet/models'\n", 466 | "for filename in os.listdir(mdldir):\n", 467 | " if \"_sd14v1\" in filename:\n", 468 | " renamed = re.sub(\"_sd14v1\", \"-fp16\", filename)\n", 469 | " os.rename(os.path.join(mdldir, filename), os.path.join(mdldir, renamed))\n", 470 | "\n", 471 | "!wget -q -O CN_models.txt https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/CN_models.txt\n", 472 | "!wget -q -O CN_models_v2.txt https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/CN_models_v2.txt\n", 473 | "!wget -q -O CN_models_XL.txt https://github.com/TheLastBen/fast-stable-diffusion/raw/main/AUTOMATIC1111_files/CN_models_XL.txt\n", 474 | "\n", 475 | "\n", 476 | "with open(\"CN_models.txt\", 'r') as f:\n", 477 | " mdllnk = f.read().splitlines()\n", 478 | "with open(\"CN_models_v2.txt\", 'r') as d:\n", 479 | " mdllnk_v2 = d.read().splitlines()\n", 480 | "with open(\"CN_models_XL.txt\", 'r') as d:\n", 481 | " mdllnk_XL = d.read().splitlines()\n", 482 | "\n", 483 | "!rm CN_models.txt CN_models_v2.txt CN_models_XL.txt\n", 484 | "\n", 485 | "\n", 486 | "if XL_Model == \"All\":\n", 487 | " for lnk_XL in mdllnk_XL:\n", 488 | " download(lnk_XL, mdldir)\n", 489 | " clear_output()\n", 490 | " inf('\\u2714 Done','success', '50px')\n", 491 | "\n", 492 | "elif XL_Model == \"None\":\n", 493 | " pass\n", 494 | " clear_output()\n", 495 | " inf('\\u2714 Done','success', '50px')\n", 496 | "\n", 497 | "else:\n", 498 | " download(globals()[XL_Model], mdldir)\n", 499 | " clear_output()\n", 500 | " inf('\\u2714 Done','success', '50px')\n", 501 | "\n", 502 | "\n", 503 | "Canny='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth'\n", 504 | "Depth='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1p_sd15_depth.pth'\n", 505 | "Lineart='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth'\n", 506 | "MLSD='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_mlsd.pth'\n", 507 | "Normal='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_normalbae.pth'\n", 508 | "OpenPose='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth'\n", 509 | "Scribble='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_scribble.pth'\n", 510 | "Seg='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_seg.pth'\n", 511 | "ip2p='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_ip2p.pth'\n", 512 | "Shuffle='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_shuffle.pth'\n", 513 | "Inpaint='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_inpaint.pth'\n", 514 | "Softedge='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth'\n", 515 | "Lineart_Anime='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15s2_lineart_anime.pth'\n", 516 | "Tile='https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth'\n", 517 | "\n", 518 | "\n", 519 | "with capture.capture_output() as cap:\n", 520 | " cfgnames=[os.path.basename(url).split('.')[0]+'.yaml' for url in mdllnk_v2]\n", 521 | " %cd 
/content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/extensions/sd-w$blsaphemy-controlnet/models\n", 522 | " for name in cfgnames:\n", 523 | " run(['cp', 'cldm_v21.yaml', name])\n", 524 | " %cd /content\n", 525 | "\n", 526 | "if v1_Model == \"All (21GB)\":\n", 527 | " for lnk in mdllnk:\n", 528 | " download(lnk, mdldir)\n", 529 | " clear_output()\n", 530 | "\n", 531 | "elif v1_Model == \"T2iadapter_Models\":\n", 532 | " mdllnk=list(filter(lambda x: 't2i' in x, mdllnk))\n", 533 | " for lnk in mdllnk:\n", 534 | " download(lnk, mdldir)\n", 535 | " clear_output()\n", 536 | "\n", 537 | "elif v1_Model == \"None\":\n", 538 | " pass\n", 539 | " clear_output()\n", 540 | "\n", 541 | "else:\n", 542 | " download(globals()[v1_Model], mdldir)\n", 543 | " clear_output()\n", 544 | "\n", 545 | "Canny='https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_canny.safetensors'\n", 546 | "Depth='https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_depth.safetensors'\n", 547 | "HED='https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_hed.safetensors'\n", 548 | "OpenPose='https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_openposev2.safetensors'\n", 549 | "Scribble='https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_scribble.safetensors'\n", 550 | "\n", 551 | "\n", 552 | "if v2_Model == \"All\":\n", 553 | " for lnk_v2 in mdllnk_v2:\n", 554 | " download(lnk_v2, mdldir)\n", 555 | " clear_output()\n", 556 | " inf('\\u2714 Done','success', '50px')\n", 557 | "\n", 558 | "elif v2_Model == \"None\":\n", 559 | " pass\n", 560 | " clear_output()\n", 561 | " inf('\\u2714 Done','success', '50px')\n", 562 | "\n", 563 | "else:\n", 564 | " download(globals()[v2_Model], mdldir)\n", 565 | " clear_output()\n", 566 | " inf('\\u2714 Done','success', '50px')\n", 567 | "\n", 568 | " #@markdown ---" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": { 575 | "cellView": "form", 576 | "id": "PjzwxTkPSPHf" 577 | }, 578 | "outputs": [], 579 | "source": [ 580 | "#@markdown # Start Stable-Diffusion\n", 581 | "from IPython.utils import capture\n", 582 | "import time\n", 583 | "import sys\n", 584 | "import fileinput\n", 585 | "from pyngrok import ngrok, conf\n", 586 | "import re\n", 587 | "\n", 588 | "\n", 589 | "Ngrok_token = \"\" #@param {type:\"string\"}\n", 590 | "\n", 591 | "#@markdown - Input your ngrok token if you want to use ngrok server\n", 592 | "\n", 593 | "User = \"\" #@param {type:\"string\"}\n", 594 | "Password= \"\" #@param {type:\"string\"}\n", 595 | "#@markdown - Add credentials to your Gradio interface (optional)\n", 596 | "\n", 597 | "auth=f\"--gradio-auth {User}:{Password}\"\n", 598 | "if User ==\"\" or Password==\"\":\n", 599 | " auth=\"\"\n", 600 | "\n", 601 | "\n", 602 | "with capture.capture_output() as cap:\n", 603 | " %cd /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/\n", 604 | " !wget -q -O extras.py https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-w$blsaphemy/master/modules/extras.py\n", 605 | " !wget -q -O sd_models.py https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-w$blsaphemy/master/modules/sd_models.py\n", 606 | " !wget -q -O /usr/local/lib/python3.11/dist-packages/gradio/blocks.py https://raw.githubusercontent.com/TheLastBen/fast-stable-diffusion/main/AUTOMATIC1111_files/blocks.py\n", 607 | " %cd /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/\n", 608 | " \n", 609 | " 
!sed -i 's@shared.opts.data\\[\"sd_model_checkpoint\"] = checkpoint_info.title@shared.opts.data\\[\"sd_model_checkpoint\"] = checkpoint_info.title;model.half()@' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/sd_models.py\n", 610 | " #!sed -i 's@ui.create_ui().*@ui.create_ui();shared.demo.queue(concurrency_count=999999,status_update_rate=0.1)@' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/webui.py\n", 611 | " !sed -i \"s@map_location='cpu'@map_location='cuda'@\" /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/extras.py\n", 612 | "\n", 613 | " !sed -i 's@possible_sd_paths =.*@possible_sd_paths = [\\\"/content/gdrive/{mainpth}/sd/stablediffusion\\\"]@' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/paths.py\n", 614 | " !sed -i 's@\\.\\.\\/@src/@g' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/paths.py\n", 615 | " !sed -i 's@src/generative-models@generative-models@g' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/paths.py\n", 616 | "\n", 617 | " !sed -i 's@print(\\\"No module.*@@' /content/gdrive/$mainpth/sd/stablediffusion/ldm/modules/diffusionmodules/model.py\n", 618 | " !sed -i 's@\\[\"sd_model_checkpoint\"\\]@\\[\"sd_model_checkpoint\", \"sd_vae\", \"CLIP_stop_at_last_layers\", \"inpainting_mask_weight\", \"initial_noise_multiplier\"\\]@g' /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/modules/shared.py\n", 619 | "\n", 620 | "share=''\n", 621 | "if Ngrok_token!=\"\":\n", 622 | " ngrok.kill()\n", 623 | " srv=ngrok.connect(7860, pyngrok_config=conf.PyngrokConfig(auth_token=Ngrok_token) , bind_tls=True).public_url\n", 624 | "\n", 625 | " for line in fileinput.input('/usr/local/lib/python3.11/dist-packages/gradio/blocks.py', inplace=True):\n", 626 | " if line.strip().startswith('self.server_name ='):\n", 627 | " line = f' self.server_name = \"{srv[8:]}\"\\n'\n", 628 | " if line.strip().startswith('self.protocol = \"https\"'):\n", 629 | " line = ' self.protocol = \"https\"\\n'\n", 630 | " if line.strip().startswith('if self.local_url.startswith(\"https\") or self.is_colab'):\n", 631 | " line = ''\n", 632 | " if line.strip().startswith('else \"http\"'):\n", 633 | " line = ''\n", 634 | " sys.stdout.write(line)\n", 635 | "else:\n", 636 | " share='--share'\n", 637 | "\n", 638 | "ckptdir=''\n", 639 | "if os.path.exists('/content/temp_models'):\n", 640 | " ckptdir='--ckpt-dir /content/temp_models'\n", 641 | "\n", 642 | "try:\n", 643 | " model\n", 644 | " if os.path.isfile(model):\n", 645 | " !python /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/webui.py $share --api --disable-safe-unpickle --enable-insecure-extension-access --no-download-sd-model --no-half-vae --ckpt \"$model\" --xformers $auth --disable-console-progressbars --skip-version-check $ckptdir\n", 646 | " else:\n", 647 | " !python /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/webui.py $share --api --disable-safe-unpickle --enable-insecure-extension-access --no-download-sd-model --no-half-vae --ckpt-dir \"$model\" --xformers $auth --disable-console-progressbars --skip-version-check\n", 648 | "except:\n", 649 | " !python /content/gdrive/$mainpth/sd/stable-diffusion-w$blsaphemy/webui.py $share --api --disable-safe-unpickle --enable-insecure-extension-access --no-download-sd-model --no-half-vae --xformers $auth --disable-console-progressbars --skip-version-check $ckptdir" 650 | ] 651 | } 652 | ], 653 | "metadata": { 654 | "accelerator": "GPU", 655 | "colab": { 656 | "provenance": [] 657 | }, 658 | "gpuClass": 
"standard", 659 | "kernelspec": { 660 | "display_name": "Python 3", 661 | "name": "python3" 662 | }, 663 | "language_info": { 664 | "name": "python" 665 | } 666 | }, 667 | "nbformat": 4, 668 | "nbformat_minor": 0 669 | } 670 | --------------------------------------------------------------------------------