├── .gitattributes ├── .github └── FUNDING.yml ├── .gitignore ├── AUTOMATIC1111_files ├── CN_models.txt ├── CN_models_v2.txt ├── blocks.py ├── frozen_dict.py └── paths.py ├── Dependencies ├── 1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb ├── A1111.txt ├── aptdeps.txt ├── cloudflared-linux-amd64.deb ├── dbdeps.txt ├── git-lfs_2.3.4-1_amd64.deb ├── google-perftools_2.5-2.2ubuntu3_all.deb ├── libc-ares2_1.15.0-1ubuntu0.2_amd64.deb ├── libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb ├── libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb ├── libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb ├── libzaria2-0_1.35.0-1build1_amd64.deb ├── man-db_2.9.1-1_amd64.deb ├── rename_1.10-1_all.deb ├── rnpd_deps.txt ├── rnpddeps.txt ├── unzip_6.0-25ubuntu1.1_amd64.deb ├── zaria2_1.35.0-1build1_amd64.deb ├── zip_3.0-11build1_amd64.deb └── zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb ├── Dreambooth ├── 1.jpg ├── 2.png ├── 3.JPG ├── 4.jpg ├── blocks.py ├── convertodiffv1.py ├── convertodiffv2-768.py ├── convertodiffv2.py ├── convertosd.py ├── convertosdv2.py ├── det.py ├── hub.py ├── ldm.zip ├── model_index.json ├── refmdlz ├── scheduler_config.json └── smart_crop.py ├── LICENSE ├── README.md ├── fast-DreamBooth.ipynb ├── fast_stable_diffusion_AUTOMATIC1111.ipynb └── precompiled ├── README.md └── T4 ├── T4 └── xfr /.gitattributes: -------------------------------------------------------------------------------- 1 | Dependencies/db_deps.tar.zst filter=lfs diff=lfs merge=lfs -text 2 | A1111_dep.tar.zst filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: thelastben 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/CN_models.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth 2 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1p_sd15_depth.pth 3 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_lineart.pth 4 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_mlsd.pth 5 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_normalbae.pth 6 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_openpose.pth 7 | 
https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_scribble.pth 8 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_seg.pth 9 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_ip2p.pth 10 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11e_sd15_shuffle.pth 11 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_inpaint.pth 12 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth 13 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15s2_lineart_anime.pth 14 | https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11u_sd15_tile.pth 15 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_keypose-fp16.safetensors 16 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_seg-fp16.safetensors 17 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_sketch-fp16.safetensors 18 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_depth-fp16.safetensors 19 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_canny-fp16.safetensors 20 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_color-fp16.safetensors 21 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_style-fp16.safetensors 22 | https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/t2iadapter_openpose-fp16.safetensors 23 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/CN_models_v2.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_canny.safetensors 2 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_depth.safetensors 3 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_hed.safetensors 4 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_openpose.safetensors 5 | https://huggingface.co/thibaud/controlnet-sd21/resolve/main/control_v11p_sd21_scribble.safetensors 6 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/frozen_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
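# A brief usage sketch (illustrative only, not part of the upstream Flax source):
# freeze() wraps a nested dict into an immutable FrozenDict, copy() and pop()
# return new instances instead of mutating in place, and unfreeze() converts
# back to a plain dict.
#
#   params = freeze({'dense': {'kernel': [1.0], 'bias': [0.0]}})
#   params = params.copy({'lr': 0.1})    # new FrozenDict with an extra entry
#   rest, dense = params.pop('dense')    # returns (new FrozenDict, removed value)
#   plain = unfreeze(params)             # back to a regular (nested) dict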
14 | 15 | """Frozen Dictionary.""" 16 | 17 | import collections 18 | from typing import Any, TypeVar, Mapping, Dict, Tuple, Union, Hashable 19 | 20 | from flax import serialization 21 | import jax 22 | 23 | 24 | class FrozenKeysView(collections.abc.KeysView): 25 | """A wrapper for a more useful repr of the keys in a frozen dict.""" 26 | 27 | def __repr__(self): 28 | return f'frozen_dict_keys({list(self)})' 29 | 30 | 31 | class FrozenValuesView(collections.abc.ValuesView): 32 | """A wrapper for a more useful repr of the values in a frozen dict.""" 33 | 34 | def __repr__(self): 35 | return f'frozen_dict_values({list(self)})' 36 | 37 | 38 | K = TypeVar('K') 39 | V = TypeVar('V') 40 | 41 | 42 | def _indent(x, num_spaces): 43 | indent_str = ' ' * num_spaces 44 | lines = x.split('\n') 45 | assert not lines[-1] 46 | # skip the final line because it's empty and should not be indented. 47 | return '\n'.join(indent_str + line for line in lines[:-1]) + '\n' 48 | 49 | 50 | # TODO(ivyzheng): change to register_pytree_with_keys_class after JAX release. 51 | @jax.tree_util.register_pytree_node_class 52 | class FrozenDict(Mapping[K, V]): 53 | """An immutable variant of the Python dict.""" 54 | __slots__ = ('_dict', '_hash') 55 | 56 | def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name 57 | # make sure the dict is as 58 | xs = dict(*args, **kwargs) 59 | if __unsafe_skip_copy__: 60 | self._dict = xs 61 | else: 62 | self._dict = _prepare_freeze(xs) 63 | 64 | self._hash = None 65 | 66 | def __getitem__(self, key): 67 | v = self._dict[key] 68 | if isinstance(v, dict): 69 | return FrozenDict(v) 70 | return v 71 | 72 | def __setitem__(self, key, value): 73 | raise ValueError('FrozenDict is immutable.') 74 | 75 | def __contains__(self, key): 76 | return key in self._dict 77 | 78 | def __iter__(self): 79 | return iter(self._dict) 80 | 81 | def __len__(self): 82 | return len(self._dict) 83 | 84 | def __repr__(self): 85 | return self.pretty_repr() 86 | 87 | def __reduce__(self): 88 | return FrozenDict, (self.unfreeze(),) 89 | 90 | def pretty_repr(self, num_spaces=4): 91 | """Returns an indented representation of the nested dictionary.""" 92 | def pretty_dict(x): 93 | if not isinstance(x, dict): 94 | return repr(x) 95 | rep = '' 96 | for key, val in x.items(): 97 | rep += f'{key}: {pretty_dict(val)},\n' 98 | if rep: 99 | return '{\n' + _indent(rep, num_spaces) + '}' 100 | else: 101 | return '{}' 102 | return f'FrozenDict({pretty_dict(self._dict)})' 103 | 104 | def __hash__(self): 105 | if self._hash is None: 106 | h = 0 107 | for key, value in self.items(): 108 | h ^= hash((key, value)) 109 | self._hash = h 110 | return self._hash 111 | 112 | def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]': 113 | """Create a new FrozenDict with additional or replaced entries.""" 114 | return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type] 115 | 116 | def keys(self): 117 | return FrozenKeysView(self) 118 | 119 | def values(self): 120 | return FrozenValuesView(self) 121 | 122 | def items(self): 123 | for key in self._dict: 124 | yield (key, self[key]) 125 | 126 | def pop(self, key: K) -> Tuple['FrozenDict[K, V]', V]: 127 | """Create a new FrozenDict where one entry is removed. 128 | 129 | Example:: 130 | 131 | state, params = variables.pop('params') 132 | 133 | Args: 134 | key: the key to remove from the dict 135 | Returns: 136 | A pair with the new FrozenDict and the removed value. 
137 | """ 138 | value = self[key] 139 | new_dict = dict(self._dict) 140 | new_dict.pop(key) 141 | new_self = type(self)(new_dict) 142 | return new_self, value 143 | 144 | def unfreeze(self) -> Dict[K, V]: 145 | """Unfreeze this FrozenDict. 146 | 147 | Returns: 148 | An unfrozen version of this FrozenDict instance. 149 | """ 150 | return unfreeze(self) 151 | 152 | # TODO(ivyzheng): remove this after JAX 0.4.6 release. 153 | def tree_flatten(self) -> Tuple[Tuple[Any, ...], Hashable]: 154 | """Flattens this FrozenDict. 155 | 156 | Returns: 157 | A flattened version of this FrozenDict instance. 158 | """ 159 | sorted_keys = sorted(self._dict) 160 | return tuple([self._dict[k] for k in sorted_keys]), tuple(sorted_keys) 161 | 162 | @classmethod 163 | def tree_unflatten(cls, keys, values): 164 | # data is already deep copied due to tree map mechanism 165 | # we can skip the deep copy in the constructor 166 | return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True) 167 | 168 | 169 | #jax.tree_util.register_keypaths( 170 | # FrozenDict, lambda fd: tuple(jax.tree_util.DictKey(k) for k in sorted(fd)) 171 | #) 172 | 173 | 174 | def _prepare_freeze(xs: Any) -> Any: 175 | """Deep copy unfrozen dicts to make the dictionary FrozenDict safe.""" 176 | if isinstance(xs, FrozenDict): 177 | # we can safely ref share the internal state of a FrozenDict 178 | # because it is immutable. 179 | return xs._dict # pylint: disable=protected-access 180 | if not isinstance(xs, dict): 181 | # return a leaf as is. 182 | return xs 183 | # recursively copy dictionary to avoid ref sharing 184 | return {key: _prepare_freeze(val) for key, val in xs.items()} 185 | 186 | 187 | def freeze(xs: Mapping[Any, Any]) -> FrozenDict[Any, Any]: 188 | """Freeze a nested dict. 189 | 190 | Makes a nested `dict` immutable by transforming it into `FrozenDict`. 191 | 192 | Args: 193 | xs: Dictionary to freeze (a regualr Python dict). 194 | Returns: 195 | The frozen dictionary. 196 | """ 197 | return FrozenDict(xs) 198 | 199 | 200 | def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: 201 | """Unfreeze a FrozenDict. 202 | 203 | Makes a mutable copy of a `FrozenDict` mutable by transforming 204 | it into (nested) dict. 205 | 206 | Args: 207 | x: Frozen dictionary to unfreeze. 208 | Returns: 209 | The unfrozen dictionary (a regular Python dict). 210 | """ 211 | if isinstance(x, FrozenDict): 212 | # deep copy internal state of a FrozenDict 213 | # the dict branch would also work here but 214 | # it is much less performant because jax.tree_util.tree_map 215 | # uses an optimized C implementation. 216 | return jax.tree_util.tree_map(lambda y: y, x._dict) # type: ignore 217 | elif isinstance(x, dict): 218 | ys = {} 219 | for key, value in x.items(): 220 | ys[key] = unfreeze(value) 221 | return ys 222 | else: 223 | return x 224 | 225 | 226 | def copy(x: Union[FrozenDict, Dict[str, Any]], add_or_replace: Union[FrozenDict, Dict[str, Any]]) -> Union[FrozenDict, Dict[str, Any]]: 227 | """Create a new dict with additional and/or replaced entries. This is a utility 228 | function that can act on either a FrozenDict or regular dict and mimics the 229 | behavior of `FrozenDict.copy`. 230 | 231 | Example:: 232 | 233 | new_variables = copy(variables, {'additional_entries': 1}) 234 | 235 | Args: 236 | x: the dictionary to be copied and updated 237 | add_or_replace: dictionary of key-value pairs to add or replace in the dict x 238 | Returns: 239 | A new dict with the additional and/or replaced entries. 
240 | """ 241 | 242 | if isinstance(x, FrozenDict): 243 | return x.copy(add_or_replace) 244 | elif isinstance(x, dict): 245 | new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x 246 | new_dict.update(add_or_replace) 247 | return new_dict 248 | raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') 249 | 250 | 251 | def pop(x: Union[FrozenDict, Dict[str, Any]], key: str) -> Tuple[Union[FrozenDict, Dict[str, Any]], Any]: 252 | """Create a new dict where one entry is removed. This is a utility 253 | function for regular dicts that mimics the behavior of `FrozenDict.pop`. 254 | 255 | Example:: 256 | 257 | state, params = pop(variables, 'params') 258 | 259 | Args: 260 | x: the dictionary to remove the entry from 261 | key: the key to remove from the dict 262 | Returns: 263 | A pair with the new dict and the removed value. 264 | """ 265 | 266 | if isinstance(x, FrozenDict): 267 | return x.pop(key) 268 | elif isinstance(x, dict): 269 | new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x 270 | value = new_dict.pop(key) 271 | return new_dict, value 272 | raise TypeError(f'Expected FrozenDict or dict, got {type(x)}') 273 | 274 | 275 | def _frozen_dict_state_dict(xs): 276 | return {key: serialization.to_state_dict(value) for key, value in xs.items()} 277 | 278 | 279 | def _restore_frozen_dict(xs, states): 280 | diff = set(map(str, xs.keys())).difference(states.keys()) 281 | if diff: 282 | raise ValueError('The target dict keys and state dict keys do not match,' 283 | f' target dict contains keys {diff} which are not present in state dict ' 284 | f'at path {serialization.current_path()}') 285 | 286 | return FrozenDict( 287 | {key: serialization.from_state_dict(value, states[key], name=key) 288 | for key, value in xs.items()}) 289 | 290 | 291 | serialization.register_serialization_state( 292 | FrozenDict, 293 | _frozen_dict_state_dict, 294 | _restore_frozen_dict) 295 | -------------------------------------------------------------------------------- /AUTOMATIC1111_files/paths.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir 5 | import modules.safe 6 | 7 | # data_path = cmd_opts_pre.data 8 | sys.path.insert(0, script_path) 9 | 10 | # search for directory of stable diffusion in following places 11 | sd_path = None 12 | possible_sd_paths = [os.path.join(script_path, '/content/gdrive/MyDrive/sd/stablediffusion'), '.', os.path.dirname(script_path)] 13 | for possible_sd_path in possible_sd_paths: 14 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 15 | sd_path = os.path.abspath(possible_sd_path) 16 | break 17 | 18 | assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) 19 | 20 | path_dirs = [ 21 | (sd_path, 'ldm', 'Stable Diffusion', []), 22 | (os.path.join(sd_path, 'src/taming-transformers'), 'taming', 'Taming Transformers', []), 23 | (os.path.join(sd_path, 'src/codeformer'), 'inference_codeformer.py', 'CodeFormer', []), 24 | (os.path.join(sd_path, 'src/blip'), 'models/blip.py', 'BLIP', []), 25 | (os.path.join(sd_path, 'src/k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 26 | ] 27 | 28 | paths = {} 29 | 30 | for d, must_exist, what, options in path_dirs: 31 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 32 | if not 
os.path.exists(must_exist_path): 33 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 34 | else: 35 | d = os.path.abspath(d) 36 | if "atstart" in options: 37 | sys.path.insert(0, d) 38 | else: 39 | sys.path.append(d) 40 | paths[what] = d 41 | 42 | class Prioritize: 43 | def __init__(self, name): 44 | self.name = name 45 | self.path = None 46 | 47 | def __enter__(self): 48 | self.path = sys.path.copy() 49 | sys.path = [paths[self.name]] + sys.path 50 | 51 | def __exit__(self, exc_type, exc_val, exc_tb): 52 | sys.path = self.path 53 | self.path = None 54 | -------------------------------------------------------------------------------- /Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/A1111.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://huggingface.co/TheLastBen/dependencies/resolve/main/sd.tar.zst 3 | https://huggingface.co/TheLastBen/dependencies/resolve/main/sd_rep.tar.zst 4 | https://huggingface.co/TheLastBen/dependencies/resolve/main/gcolabdeps.tar.zst 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb 6 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb 7 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb 8 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb 9 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb 10 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/cloudflared-linux-amd64.deb 11 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb 12 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb 13 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/man-db_2.9.1-1_amd64.deb 14 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zaria2_1.35.0-1build1_amd64.deb 15 | -------------------------------------------------------------------------------- /Dependencies/aptdeps.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 4 | -------------------------------------------------------------------------------- /Dependencies/cloudflared-linux-amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/cloudflared-linux-amd64.deb 
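The *.txt files in Dependencies/ (A1111.txt and aptdeps.txt above, dbdeps.txt and the rnpd manifests below) are plain manifests with one download URL per line. As a minimal sketch of how such a manifest could be consumed with only the Python standard library — the helper name fetch_manifest and the deps/ target directory are illustrative, not something the notebooks actually define:

import urllib.request
from pathlib import Path

def fetch_manifest(manifest_path: str, dest_dir: str = "deps") -> None:
    # Download every URL listed, one per line, in a dependency manifest.
    out = Path(dest_dir)
    out.mkdir(parents=True, exist_ok=True)
    for line in Path(manifest_path).read_text().splitlines():
        url = line.strip()
        if not url:
            continue  # skip blank lines
        target = out / url.rsplit("/", 1)[-1]  # keep the original file name
        urllib.request.urlretrieve(url, target)

# e.g. fetch_manifest("Dependencies/aptdeps.txt")

The notebooks themselves may rely on wget or aria2 instead; this only illustrates the shape of the manifest data.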
-------------------------------------------------------------------------------- /Dependencies/dbdeps.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/TheLastBen/dependencies/resolve/main/gcolabdeps.tar.zst 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/1libunwind-dev_1.2.1-9ubuntu0.1_amd64.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb 6 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb 7 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb 8 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb 9 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb 10 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/man-db_2.9.1-1_amd64.deb 11 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zaria2_1.35.0-1build1_amd64.deb 12 | -------------------------------------------------------------------------------- /Dependencies/git-lfs_2.3.4-1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/git-lfs_2.3.4-1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/google-perftools_2.5-2.2ubuntu3_all.deb -------------------------------------------------------------------------------- /Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/libc-ares2_1.15.0-1ubuntu0.2_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/libgoogle-perftools-dev_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/libgoogle-perftools4_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/libtcmalloc-minimal4_2.5-2.2ubuntu3_amd64.deb -------------------------------------------------------------------------------- /Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/libzaria2-0_1.35.0-1build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/man-db_2.9.1-1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/man-db_2.9.1-1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/rename_1.10-1_all.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/rename_1.10-1_all.deb -------------------------------------------------------------------------------- /Dependencies/rnpd_deps.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zip_3.0-11build1_amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb 6 | https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpd-310.tar.zst 7 | -------------------------------------------------------------------------------- /Dependencies/rnpddeps.txt: -------------------------------------------------------------------------------- 1 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb 2 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/git-lfs_2.3.4-1_amd64.deb 3 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/rename_1.10-1_all.deb 4 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/zip_3.0-11build1_amd64.deb 5 | https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb 6 | https://huggingface.co/TheLastBen/dependencies/resolve/main/rnpddeps.tar.zst 7 | -------------------------------------------------------------------------------- /Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/unzip_6.0-25ubuntu1.1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zaria2_1.35.0-1build1_amd64.deb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/zaria2_1.35.0-1build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zip_3.0-11build1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/zip_3.0-11build1_amd64.deb -------------------------------------------------------------------------------- /Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dependencies/zstd_1.4.4+dfsg-3ubuntu0.1_amd64.deb -------------------------------------------------------------------------------- /Dreambooth/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/1.jpg -------------------------------------------------------------------------------- /Dreambooth/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/2.png -------------------------------------------------------------------------------- /Dreambooth/3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/3.JPG -------------------------------------------------------------------------------- /Dreambooth/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/4.jpg -------------------------------------------------------------------------------- /Dreambooth/convertodiffv1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig 5 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 6 | 7 | 8 | 9 | # DiffUsers版StableDiffusionのモデルパラメータ 10 | NUM_TRAIN_TIMESTEPS = 1000 11 | BETA_START = 0.00085 12 | BETA_END = 0.0120 13 | 14 | UNET_PARAMS_MODEL_CHANNELS = 320 15 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] 16 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] 17 | UNET_PARAMS_IMAGE_SIZE = 64 18 | UNET_PARAMS_IN_CHANNELS = 4 19 | UNET_PARAMS_OUT_CHANNELS = 4 20 | UNET_PARAMS_NUM_RES_BLOCKS = 2 21 | UNET_PARAMS_CONTEXT_DIM = 768 22 | UNET_PARAMS_NUM_HEADS = 8 23 | 24 | VAE_PARAMS_Z_CHANNELS = 4 25 | VAE_PARAMS_RESOLUTION = 512 26 | VAE_PARAMS_IN_CHANNELS = 3 27 | VAE_PARAMS_OUT_CH = 3 28 | VAE_PARAMS_CH = 128 29 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4] 30 | VAE_PARAMS_NUM_RES_BLOCKS = 2 31 | 32 | # V2 33 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] 34 | V2_UNET_PARAMS_CONTEXT_DIM = 1024 35 | 36 | 37 | # region StableDiffusion->Diffusersの変換コード 38 | # convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0) 39 | 40 | 41 | def shave_segments(path, n_shave_prefix_segments=1): 42 | """ 43 | Removes segments. 
Positive values shave the first segments, negative shave the last segments. 44 | """ 45 | if n_shave_prefix_segments >= 0: 46 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 47 | else: 48 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 49 | 50 | 51 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 52 | """ 53 | Updates paths inside resnets to the new naming scheme (local renaming) 54 | """ 55 | mapping = [] 56 | for old_item in old_list: 57 | new_item = old_item.replace("in_layers.0", "norm1") 58 | new_item = new_item.replace("in_layers.2", "conv1") 59 | 60 | new_item = new_item.replace("out_layers.0", "norm2") 61 | new_item = new_item.replace("out_layers.3", "conv2") 62 | 63 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 64 | new_item = new_item.replace("skip_connection", "conv_shortcut") 65 | 66 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 67 | 68 | mapping.append({"old": old_item, "new": new_item}) 69 | 70 | return mapping 71 | 72 | 73 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 74 | """ 75 | Updates paths inside resnets to the new naming scheme (local renaming) 76 | """ 77 | mapping = [] 78 | for old_item in old_list: 79 | new_item = old_item 80 | 81 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 82 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 83 | 84 | mapping.append({"old": old_item, "new": new_item}) 85 | 86 | return mapping 87 | 88 | 89 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 90 | """ 91 | Updates paths inside attentions to the new naming scheme (local renaming) 92 | """ 93 | mapping = [] 94 | for old_item in old_list: 95 | new_item = old_item 96 | 97 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 98 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 99 | 100 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 101 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 102 | 103 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 104 | 105 | mapping.append({"old": old_item, "new": new_item}) 106 | 107 | return mapping 108 | 109 | 110 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 111 | """ 112 | Updates paths inside attentions to the new naming scheme (local renaming) 113 | """ 114 | mapping = [] 115 | for old_item in old_list: 116 | new_item = old_item 117 | 118 | new_item = new_item.replace("norm.weight", "group_norm.weight") 119 | new_item = new_item.replace("norm.bias", "group_norm.bias") 120 | 121 | new_item = new_item.replace("q.weight", "query.weight") 122 | new_item = new_item.replace("q.bias", "query.bias") 123 | 124 | new_item = new_item.replace("k.weight", "key.weight") 125 | new_item = new_item.replace("k.bias", "key.bias") 126 | 127 | new_item = new_item.replace("v.weight", "value.weight") 128 | new_item = new_item.replace("v.bias", "value.bias") 129 | 130 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 131 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 132 | 133 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 134 | 135 | mapping.append({"old": old_item, "new": new_item}) 136 | 137 | return mapping 138 | 139 | 140 | def assign_to_checkpoint( 141 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 142 | ): 143 | 
""" 144 | This does the final conversion step: take locally converted weights and apply a global renaming 145 | to them. It splits attention layers, and takes into account additional replacements 146 | that may arise. 147 | 148 | Assigns the weights to the new checkpoint. 149 | """ 150 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 151 | 152 | # Splits the attention layers into three variables. 153 | if attention_paths_to_split is not None: 154 | for path, path_map in attention_paths_to_split.items(): 155 | old_tensor = old_checkpoint[path] 156 | channels = old_tensor.shape[0] // 3 157 | 158 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 159 | 160 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 161 | 162 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 163 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 164 | 165 | checkpoint[path_map["query"]] = query.reshape(target_shape) 166 | checkpoint[path_map["key"]] = key.reshape(target_shape) 167 | checkpoint[path_map["value"]] = value.reshape(target_shape) 168 | 169 | for path in paths: 170 | new_path = path["new"] 171 | 172 | # These have already been assigned 173 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 174 | continue 175 | 176 | # Global renaming happens here 177 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 178 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 179 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 180 | 181 | if additional_replacements is not None: 182 | for replacement in additional_replacements: 183 | new_path = new_path.replace(replacement["old"], replacement["new"]) 184 | 185 | # proj_attn.weight has to be converted from conv 1D to linear 186 | if "proj_attn.weight" in new_path: 187 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 188 | else: 189 | checkpoint[new_path] = old_checkpoint[path["old"]] 190 | 191 | 192 | def conv_attn_to_linear(checkpoint): 193 | keys = list(checkpoint.keys()) 194 | attn_keys = ["query.weight", "key.weight", "value.weight"] 195 | for key in keys: 196 | if ".".join(key.split(".")[-2:]) in attn_keys: 197 | if checkpoint[key].ndim > 2: 198 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 199 | elif "proj_attn.weight" in key: 200 | if checkpoint[key].ndim > 2: 201 | checkpoint[key] = checkpoint[key][:, :, 0] 202 | 203 | 204 | def linear_transformer_to_conv(checkpoint): 205 | keys = list(checkpoint.keys()) 206 | tf_keys = ["proj_in.weight", "proj_out.weight"] 207 | for key in keys: 208 | if ".".join(key.split(".")[-2:]) in tf_keys: 209 | if checkpoint[key].ndim == 2: 210 | checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) 211 | 212 | 213 | def convert_ldm_unet_checkpoint(v2, checkpoint, config): 214 | """ 215 | Takes a state dict and a config, and returns a converted checkpoint. 216 | """ 217 | 218 | # extract state_dict for UNet 219 | unet_state_dict = {} 220 | unet_key = "model.diffusion_model." 
221 | keys = list(checkpoint.keys()) 222 | for key in keys: 223 | if key.startswith(unet_key): 224 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 225 | 226 | new_checkpoint = {} 227 | 228 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 229 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 230 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 231 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 232 | 233 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 234 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 235 | 236 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 237 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 238 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 239 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 240 | 241 | # Retrieves the keys for the input blocks only 242 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 243 | input_blocks = { 244 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 245 | for layer_id in range(num_input_blocks) 246 | } 247 | 248 | # Retrieves the keys for the middle blocks only 249 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 250 | middle_blocks = { 251 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 252 | for layer_id in range(num_middle_blocks) 253 | } 254 | 255 | # Retrieves the keys for the output blocks only 256 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 257 | output_blocks = { 258 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 259 | for layer_id in range(num_output_blocks) 260 | } 261 | 262 | for i in range(1, num_input_blocks): 263 | block_id = (i - 1) // (config["layers_per_block"] + 1) 264 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 265 | 266 | resnets = [ 267 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 268 | ] 269 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 270 | 271 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 272 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 273 | f"input_blocks.{i}.0.op.weight" 274 | ) 275 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 276 | f"input_blocks.{i}.0.op.bias" 277 | ) 278 | 279 | paths = renew_resnet_paths(resnets) 280 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 281 | assign_to_checkpoint( 282 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 283 | ) 284 | 285 | if len(attentions): 286 | paths = renew_attention_paths(attentions) 287 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 288 | assign_to_checkpoint( 289 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 290 | ) 291 | 292 | resnet_0 = middle_blocks[0] 293 | attentions = middle_blocks[1] 294 | resnet_1 = 
middle_blocks[2] 295 | 296 | resnet_0_paths = renew_resnet_paths(resnet_0) 297 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 298 | 299 | resnet_1_paths = renew_resnet_paths(resnet_1) 300 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 301 | 302 | attentions_paths = renew_attention_paths(attentions) 303 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 304 | assign_to_checkpoint( 305 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 306 | ) 307 | 308 | for i in range(num_output_blocks): 309 | block_id = i // (config["layers_per_block"] + 1) 310 | layer_in_block_id = i % (config["layers_per_block"] + 1) 311 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 312 | output_block_list = {} 313 | 314 | for layer in output_block_layers: 315 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 316 | if layer_id in output_block_list: 317 | output_block_list[layer_id].append(layer_name) 318 | else: 319 | output_block_list[layer_id] = [layer_name] 320 | 321 | if len(output_block_list) > 1: 322 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 323 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 324 | 325 | resnet_0_paths = renew_resnet_paths(resnets) 326 | paths = renew_resnet_paths(resnets) 327 | 328 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 329 | assign_to_checkpoint( 330 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 331 | ) 332 | 333 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 334 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 335 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 336 | f"output_blocks.{i}.{index}.conv.weight" 337 | ] 338 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 339 | f"output_blocks.{i}.{index}.conv.bias" 340 | ] 341 | 342 | # Clear attentions as they have been attributed above. 343 | if len(attentions) == 2: 344 | attentions = [] 345 | 346 | if len(attentions): 347 | paths = renew_attention_paths(attentions) 348 | meta_path = { 349 | "old": f"output_blocks.{i}.1", 350 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 351 | } 352 | assign_to_checkpoint( 353 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 354 | ) 355 | else: 356 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 357 | for path in resnet_0_paths: 358 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 359 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 360 | 361 | new_checkpoint[new_path] = unet_state_dict[old_path] 362 | 363 | # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する 364 | if v2: 365 | linear_transformer_to_conv(new_checkpoint) 366 | 367 | return new_checkpoint 368 | 369 | 370 | def convert_ldm_vae_checkpoint(checkpoint, config): 371 | # extract state dict for VAE 372 | vae_state_dict = {} 373 | vae_key = "first_stage_model." 
374 | keys = list(checkpoint.keys()) 375 | for key in keys: 376 | if key.startswith(vae_key): 377 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 378 | # if len(vae_state_dict) == 0: 379 | # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict 380 | # vae_state_dict = checkpoint 381 | 382 | new_checkpoint = {} 383 | 384 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 385 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 386 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 387 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 388 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 389 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 390 | 391 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 392 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 393 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 394 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 395 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 396 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 397 | 398 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 399 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 400 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 401 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 402 | 403 | # Retrieves the keys for the encoder down blocks only 404 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 405 | down_blocks = { 406 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 407 | } 408 | 409 | # Retrieves the keys for the decoder up blocks only 410 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 411 | up_blocks = { 412 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 413 | } 414 | 415 | for i in range(num_down_blocks): 416 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 417 | 418 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 419 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 420 | f"encoder.down.{i}.downsample.conv.weight" 421 | ) 422 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 423 | f"encoder.down.{i}.downsample.conv.bias" 424 | ) 425 | 426 | paths = renew_vae_resnet_paths(resnets) 427 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 428 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 429 | 430 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 431 | num_mid_res_blocks = 2 432 | for i in range(1, num_mid_res_blocks + 1): 433 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 434 | 435 | paths = renew_vae_resnet_paths(resnets) 436 | meta_path = {"old": f"mid.block_{i}", 
"new": f"mid_block.resnets.{i - 1}"} 437 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 438 | 439 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 440 | paths = renew_vae_attention_paths(mid_attentions) 441 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 442 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 443 | conv_attn_to_linear(new_checkpoint) 444 | 445 | for i in range(num_up_blocks): 446 | block_id = num_up_blocks - 1 - i 447 | resnets = [ 448 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 449 | ] 450 | 451 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 452 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 453 | f"decoder.up.{block_id}.upsample.conv.weight" 454 | ] 455 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 456 | f"decoder.up.{block_id}.upsample.conv.bias" 457 | ] 458 | 459 | paths = renew_vae_resnet_paths(resnets) 460 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 461 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 462 | 463 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 464 | num_mid_res_blocks = 2 465 | for i in range(1, num_mid_res_blocks + 1): 466 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 467 | 468 | paths = renew_vae_resnet_paths(resnets) 469 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 470 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 471 | 472 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 473 | paths = renew_vae_attention_paths(mid_attentions) 474 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 475 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 476 | conv_attn_to_linear(new_checkpoint) 477 | return new_checkpoint 478 | 479 | 480 | def create_unet_diffusers_config(v2): 481 | """ 482 | Creates a config for the diffusers based on the config of the LDM model. 
483 | """ 484 | # unet_params = original_config.model.params.unet_config.params 485 | 486 | block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] 487 | 488 | down_block_types = [] 489 | resolution = 1 490 | for i in range(len(block_out_channels)): 491 | block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" 492 | down_block_types.append(block_type) 493 | if i != len(block_out_channels) - 1: 494 | resolution *= 2 495 | 496 | up_block_types = [] 497 | for i in range(len(block_out_channels)): 498 | block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" 499 | up_block_types.append(block_type) 500 | resolution //= 2 501 | 502 | config = dict( 503 | sample_size=UNET_PARAMS_IMAGE_SIZE, 504 | in_channels=UNET_PARAMS_IN_CHANNELS, 505 | out_channels=UNET_PARAMS_OUT_CHANNELS, 506 | down_block_types=tuple(down_block_types), 507 | up_block_types=tuple(up_block_types), 508 | block_out_channels=tuple(block_out_channels), 509 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, 510 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, 511 | attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, 512 | ) 513 | 514 | return config 515 | 516 | 517 | def create_vae_diffusers_config(): 518 | """ 519 | Creates a config for the diffusers based on the config of the LDM model. 520 | """ 521 | # vae_params = original_config.model.params.first_stage_config.params.ddconfig 522 | # _ = original_config.model.params.first_stage_config.params.embed_dim 523 | block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] 524 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 525 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 526 | 527 | config = dict( 528 | sample_size=VAE_PARAMS_RESOLUTION, 529 | in_channels=VAE_PARAMS_IN_CHANNELS, 530 | out_channels=VAE_PARAMS_OUT_CH, 531 | down_block_types=tuple(down_block_types), 532 | up_block_types=tuple(up_block_types), 533 | block_out_channels=tuple(block_out_channels), 534 | latent_channels=VAE_PARAMS_Z_CHANNELS, 535 | layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, 536 | ) 537 | return config 538 | 539 | 540 | def convert_ldm_clip_checkpoint_v1(checkpoint): 541 | keys = list(checkpoint.keys()) 542 | text_model_dict = {} 543 | for key in keys: 544 | if key.startswith("cond_stage_model.transformer"): 545 | text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] 546 | return text_model_dict 547 | 548 | 549 | def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): 550 | # 嫌になるくらい違うぞ! 551 | def convert_key(key): 552 | if not key.startswith("cond_stage_model"): 553 | return None 554 | 555 | # common conversion 556 | key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") 557 | key = key.replace("cond_stage_model.model.", "text_model.") 558 | 559 | if "resblocks" in key: 560 | # resblocks conversion 561 | key = key.replace(".resblocks.", ".layers.") 562 | if ".ln_" in key: 563 | key = key.replace(".ln_", ".layer_norm") 564 | elif ".mlp." 
in key: 565 | key = key.replace(".c_fc.", ".fc1.") 566 | key = key.replace(".c_proj.", ".fc2.") 567 | elif '.attn.out_proj' in key: 568 | key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") 569 | elif '.attn.in_proj' in key: 570 | key = None # 特殊なので後で処理する 571 | else: 572 | raise ValueError(f"unexpected key in SD: {key}") 573 | elif '.positional_embedding' in key: 574 | key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") 575 | elif '.text_projection' in key: 576 | key = None # 使われない??? 577 | elif '.logit_scale' in key: 578 | key = None # 使われない??? 579 | elif '.token_embedding' in key: 580 | key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") 581 | elif '.ln_final' in key: 582 | key = key.replace(".ln_final", ".final_layer_norm") 583 | return key 584 | 585 | keys = list(checkpoint.keys()) 586 | new_sd = {} 587 | for key in keys: 588 | # remove resblocks 23 589 | if '.resblocks.23.' in key: 590 | continue 591 | new_key = convert_key(key) 592 | if new_key is None: 593 | continue 594 | new_sd[new_key] = checkpoint[key] 595 | 596 | # attnの変換 597 | for key in keys: 598 | if '.resblocks.23.' in key: 599 | continue 600 | if '.resblocks' in key and '.attn.in_proj_' in key: 601 | # 三つに分割 602 | values = torch.chunk(checkpoint[key], 3) 603 | 604 | key_suffix = ".weight" if "weight" in key else ".bias" 605 | key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") 606 | key_pfx = key_pfx.replace("_weight", "") 607 | key_pfx = key_pfx.replace("_bias", "") 608 | key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") 609 | new_sd[key_pfx + "q_proj" + key_suffix] = values[0] 610 | new_sd[key_pfx + "k_proj" + key_suffix] = values[1] 611 | new_sd[key_pfx + "v_proj" + key_suffix] = values[2] 612 | 613 | # position_idsの追加 614 | new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) 615 | return new_sd 616 | 617 | # endregion 618 | 619 | 620 | # region Diffusers->StableDiffusion の変換コード 621 | # convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0) 622 | 623 | def conv_transformer_to_linear(checkpoint): 624 | keys = list(checkpoint.keys()) 625 | tf_keys = ["proj_in.weight", "proj_out.weight"] 626 | for key in keys: 627 | if ".".join(key.split(".")[-2:]) in tf_keys: 628 | if checkpoint[key].ndim > 2: 629 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 630 | 631 | 632 | def convert_unet_state_dict_to_sd(v2, unet_state_dict): 633 | unet_conversion_map = [ 634 | # (stable-diffusion, HF Diffusers) 635 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 636 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 637 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 638 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 639 | ("input_blocks.0.0.weight", "conv_in.weight"), 640 | ("input_blocks.0.0.bias", "conv_in.bias"), 641 | ("out.0.weight", "conv_norm_out.weight"), 642 | ("out.0.bias", "conv_norm_out.bias"), 643 | ("out.2.weight", "conv_out.weight"), 644 | ("out.2.bias", "conv_out.bias"), 645 | ] 646 | 647 | unet_conversion_map_resnet = [ 648 | # (stable-diffusion, HF Diffusers) 649 | ("in_layers.0", "norm1"), 650 | ("in_layers.2", "conv1"), 651 | ("out_layers.0", "norm2"), 652 | ("out_layers.3", "conv2"), 653 | ("emb_layers.1", "time_emb_proj"), 654 | ("skip_connection", "conv_shortcut"), 655 | ] 656 | 657 | unet_conversion_map_layer = [] 658 | for i in range(4): 659 | # loop over downblocks/upblocks 660 | 661 | for j 
in range(2): 662 | # loop over resnets/attentions for downblocks 663 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 664 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 665 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 666 | 667 | if i < 3: 668 | # no attention layers in down_blocks.3 669 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 670 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 671 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 672 | 673 | for j in range(3): 674 | # loop over resnets/attentions for upblocks 675 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 676 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 677 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 678 | 679 | if i > 0: 680 | # no attention layers in up_blocks.0 681 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 682 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 683 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 684 | 685 | if i < 3: 686 | # no downsample in down_blocks.3 687 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 688 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 689 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 690 | 691 | # no upsample in up_blocks.3 692 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 693 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 694 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 695 | 696 | hf_mid_atn_prefix = "mid_block.attentions.0." 697 | sd_mid_atn_prefix = "middle_block.1." 698 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 699 | 700 | for j in range(2): 701 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 702 | sd_mid_res_prefix = f"middle_block.{2*j}." 703 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 704 | 705 | # buyer beware: this is a *brittle* function, 706 | # and correct output requires that all of these pieces interact in 707 | # the exact order in which I have arranged them. 708 | mapping = {k: k for k in unet_state_dict.keys()} 709 | for sd_name, hf_name in unet_conversion_map: 710 | mapping[hf_name] = sd_name 711 | for k, v in mapping.items(): 712 | if "resnets" in k: 713 | for sd_part, hf_part in unet_conversion_map_resnet: 714 | v = v.replace(hf_part, sd_part) 715 | mapping[k] = v 716 | for k, v in mapping.items(): 717 | for sd_part, hf_part in unet_conversion_map_layer: 718 | v = v.replace(hf_part, sd_part) 719 | mapping[k] = v 720 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 721 | 722 | if v2: 723 | conv_transformer_to_linear(new_state_dict) 724 | 725 | return new_state_dict 726 | 727 | 728 | # ================# 729 | # VAE Conversion # 730 | # ================# 731 | 732 | def reshape_weight_for_sd(w): 733 | # convert HF linear weights to SD conv2d weights 734 | return w.reshape(*w.shape, 1, 1) 735 | 736 | 737 | def convert_vae_state_dict(vae_state_dict): 738 | vae_conversion_map = [ 739 | # (stable-diffusion, HF Diffusers) 740 | ("nin_shortcut", "conv_shortcut"), 741 | ("norm_out", "conv_norm_out"), 742 | ("mid.attn_1.", "mid_block.attentions.0."), 743 | ] 744 | 745 | for i in range(4): 746 | # down_blocks have two resnets 747 | for j in range(2): 748 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 749 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 
750 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 751 | 752 | if i < 3: 753 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 754 | sd_downsample_prefix = f"down.{i}.downsample." 755 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 756 | 757 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 758 | sd_upsample_prefix = f"up.{3-i}.upsample." 759 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 760 | 761 | # up_blocks have three resnets 762 | # also, up blocks in hf are numbered in reverse from sd 763 | for j in range(3): 764 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 765 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 766 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 767 | 768 | # this part accounts for mid blocks in both the encoder and the decoder 769 | for i in range(2): 770 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 771 | sd_mid_res_prefix = f"mid.block_{i+1}." 772 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 773 | 774 | vae_conversion_map_attn = [ 775 | # (stable-diffusion, HF Diffusers) 776 | ("norm.", "group_norm."), 777 | ("q.", "query."), 778 | ("k.", "key."), 779 | ("v.", "value."), 780 | ("proj_out.", "proj_attn."), 781 | ] 782 | 783 | mapping = {k: k for k in vae_state_dict.keys()} 784 | for k, v in mapping.items(): 785 | for sd_part, hf_part in vae_conversion_map: 786 | v = v.replace(hf_part, sd_part) 787 | mapping[k] = v 788 | for k, v in mapping.items(): 789 | if "attentions" in k: 790 | for sd_part, hf_part in vae_conversion_map_attn: 791 | v = v.replace(hf_part, sd_part) 792 | mapping[k] = v 793 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 794 | weights_to_convert = ["q", "k", "v", "proj_out"] 795 | 796 | for k, v in new_state_dict.items(): 797 | for weight_name in weights_to_convert: 798 | if f"mid.attn_1.{weight_name}.weight" in k: 799 | new_state_dict[k] = reshape_weight_for_sd(v) 800 | 801 | return new_state_dict 802 | 803 | 804 | # endregion 805 | 806 | 807 | def load_checkpoint_with_text_encoder_conversion(ckpt_path): 808 | # text encoderの格納形式が違うモデルに対応する ('text_model'がない) 809 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 810 | ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), 811 | ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), 812 | ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') 813 | ] 814 | 815 | checkpoint = torch.load(ckpt_path, map_location="cuda") 816 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 817 | key_reps = [] 818 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 819 | for key in state_dict.keys(): 820 | if key.startswith(rep_from): 821 | new_key = rep_to + key[len(rep_from):] 822 | key_reps.append((key, new_key)) 823 | 824 | for key, new_key in key_reps: 825 | state_dict[new_key] = state_dict[key] 826 | del state_dict[key] 827 | 828 | return checkpoint 829 | 830 | 831 | # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 832 | def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): 833 | 834 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 835 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 836 | if dtype is not None: 837 | for k, v in state_dict.items(): 838 | if type(v) is torch.Tensor: 839 | state_dict[k] = v.to(dtype) 840 | 
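# Note: the load_state_dict() calls below use PyTorch's default strict=True, so any key that
# failed to convert raises a RuntimeError immediately; the returned `info` object
# (missing_keys / unexpected_keys) only carries entries when strict=False is passed, which can
# be useful when debugging a conversion.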
841 | # Convert the UNet2DConditionModel model. 842 | unet_config = create_unet_diffusers_config(v2) 843 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) 844 | 845 | unet = UNet2DConditionModel(**unet_config) 846 | info = unet.load_state_dict(converted_unet_checkpoint) 847 | 848 | 849 | # Convert the VAE model. 850 | vae_config = create_vae_diffusers_config() 851 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) 852 | 853 | vae = AutoencoderKL(**vae_config) 854 | info = vae.load_state_dict(converted_vae_checkpoint) 855 | 856 | 857 | # convert text_model 858 | if v2: 859 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) 860 | cfg = CLIPTextConfig( 861 | vocab_size=49408, 862 | hidden_size=1024, 863 | intermediate_size=4096, 864 | num_hidden_layers=23, 865 | num_attention_heads=16, 866 | max_position_embeddings=77, 867 | hidden_act="gelu", 868 | layer_norm_eps=1e-05, 869 | dropout=0.0, 870 | attention_dropout=0.0, 871 | initializer_range=0.02, 872 | initializer_factor=1.0, 873 | pad_token_id=1, 874 | bos_token_id=0, 875 | eos_token_id=2, 876 | model_type="clip_text_model", 877 | projection_dim=512, 878 | torch_dtype="float32", 879 | transformers_version="4.25.0.dev0", 880 | ) 881 | text_model = CLIPTextModel._from_config(cfg) 882 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 883 | else: 884 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) 885 | cfg = CLIPTextConfig( 886 | vocab_size=49408, 887 | hidden_size=768, 888 | intermediate_size=3072, 889 | num_hidden_layers=12, 890 | num_attention_heads=12, 891 | max_position_embeddings=77, 892 | hidden_act="quick_gelu", 893 | layer_norm_eps=1e-05, 894 | dropout=0.0, 895 | attention_dropout=0.0, 896 | initializer_range=0.02, 897 | initializer_factor=1.0, 898 | pad_token_id=1, 899 | bos_token_id=0, 900 | eos_token_id=2, 901 | model_type="clip_text_model", 902 | projection_dim=768, 903 | torch_dtype="float32", 904 | transformers_version="4.16.0.dev0", 905 | ) 906 | 907 | 908 | text_model = CLIPTextModel._from_config(cfg) 909 | #text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 910 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 911 | 912 | 913 | return text_model, vae, unet 914 | 915 | 916 | def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): 917 | def convert_key(key): 918 | # position_idsの除去 919 | if ".position_ids" in key: 920 | return None 921 | 922 | # common 923 | key = key.replace("text_model.encoder.", "transformer.") 924 | key = key.replace("text_model.", "") 925 | if "layers" in key: 926 | # resblocks conversion 927 | key = key.replace(".layers.", ".resblocks.") 928 | if ".layer_norm" in key: 929 | key = key.replace(".layer_norm", ".ln_") 930 | elif ".mlp." in key: 931 | key = key.replace(".fc1.", ".c_fc.") 932 | key = key.replace(".fc2.", ".c_proj.") 933 | elif '.self_attn.out_proj' in key: 934 | key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") 935 | elif '.self_attn.' 
in key: 936 | key = None # 特殊なので後で処理する 937 | else: 938 | raise ValueError(f"unexpected key in DiffUsers model: {key}") 939 | elif '.position_embedding' in key: 940 | key = key.replace("embeddings.position_embedding.weight", "positional_embedding") 941 | elif '.token_embedding' in key: 942 | key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") 943 | elif 'final_layer_norm' in key: 944 | key = key.replace("final_layer_norm", "ln_final") 945 | return key 946 | 947 | keys = list(checkpoint.keys()) 948 | new_sd = {} 949 | for key in keys: 950 | new_key = convert_key(key) 951 | if new_key is None: 952 | continue 953 | new_sd[new_key] = checkpoint[key] 954 | 955 | # attnの変換 956 | for key in keys: 957 | if 'layers' in key and 'q_proj' in key: 958 | # 三つを結合 959 | key_q = key 960 | key_k = key.replace("q_proj", "k_proj") 961 | key_v = key.replace("q_proj", "v_proj") 962 | 963 | value_q = checkpoint[key_q] 964 | value_k = checkpoint[key_k] 965 | value_v = checkpoint[key_v] 966 | value = torch.cat([value_q, value_k, value_v]) 967 | 968 | new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") 969 | new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") 970 | new_sd[new_key] = value 971 | 972 | # 最後の層などを捏造するか 973 | if make_dummy_weights: 974 | 975 | keys = list(new_sd.keys()) 976 | for key in keys: 977 | if key.startswith("transformer.resblocks.22."): 978 | new_sd[key.replace(".22.", ".23.")] = new_sd[key] 979 | 980 | # Diffusersに含まれない重みを作っておく 981 | new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) 982 | new_sd['logit_scale'] = torch.tensor(1) 983 | 984 | return new_sd 985 | 986 | 987 | def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): 988 | if ckpt_path is not None: 989 | # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む 990 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 991 | state_dict = checkpoint["state_dict"] 992 | strict = True 993 | else: 994 | # 新しく作る 995 | checkpoint = {} 996 | state_dict = {} 997 | strict = False 998 | 999 | def update_sd(prefix, sd): 1000 | for k, v in sd.items(): 1001 | key = prefix + k 1002 | assert not strict or key in state_dict, f"Illegal key in save SD: {key}" 1003 | if save_dtype is not None: 1004 | v = v.detach().clone().to("cpu").to(save_dtype) 1005 | state_dict[key] = v 1006 | 1007 | # Convert the UNet model 1008 | unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) 1009 | update_sd("model.diffusion_model.", unet_state_dict) 1010 | 1011 | # Convert the text encoder model 1012 | if v2: 1013 | make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる 1014 | text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) 1015 | update_sd("cond_stage_model.model.", text_enc_dict) 1016 | else: 1017 | text_enc_dict = text_encoder.state_dict() 1018 | update_sd("cond_stage_model.transformer.", text_enc_dict) 1019 | 1020 | # Convert the VAE 1021 | if vae is not None: 1022 | vae_dict = convert_vae_state_dict(vae.state_dict()) 1023 | update_sd("first_stage_model.", vae_dict) 1024 | 1025 | # Put together new checkpoint 1026 | key_count = len(state_dict.keys()) 1027 | new_ckpt = {'state_dict': state_dict} 1028 | 1029 | if 'epoch' in checkpoint: 1030 | epochs += checkpoint['epoch'] 1031 | if 'global_step' in checkpoint: 1032 | steps += checkpoint['global_step'] 1033 | 1034 | 
new_ckpt['epoch'] = epochs 1035 | new_ckpt['global_step'] = steps 1036 | 1037 | torch.save(new_ckpt, output_file) 1038 | 1039 | return key_count 1040 | 1041 | 1042 | def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, vae=None): 1043 | if vae is None: 1044 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 1045 | 1046 | pipeline = StableDiffusionPipeline( 1047 | unet=unet, 1048 | text_encoder=text_encoder, 1049 | vae=vae, 1050 | scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler"), 1051 | tokenizer=CLIPTokenizer.from_pretrained("refmdl", subfolder="tokenizer"), 1052 | ) 1053 | pipeline.save_pretrained(output_dir) 1054 | 1055 | 1056 | 1057 | def convert(args): 1058 | print("\033[1;32mConverting to Diffusers ...") 1059 | load_dtype = torch.float16 if args.fp16 else None 1060 | 1061 | save_dtype = None 1062 | if args.fp16: 1063 | save_dtype = torch.float16 1064 | elif args.bf16: 1065 | save_dtype = torch.bfloat16 1066 | elif args.float: 1067 | save_dtype = torch.float 1068 | 1069 | is_load_ckpt = os.path.isfile(args.model_to_load) 1070 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 1071 | 1072 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint" 1073 | assert is_save_ckpt is not None, f"reference model is required to save as Diffusers" 1074 | 1075 | # load the model 1076 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 1077 | 1078 | 1079 | if is_load_ckpt: 1080 | v2_model = args.v2 1081 | text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) 1082 | else: 1083 | pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) 1084 | text_encoder = pipe.text_encoder 1085 | vae = pipe.vae 1086 | unet = pipe.unet 1087 | 1088 | if args.v1 == args.v2: 1089 | # auto-detect the model version 1090 | v2_model = unet.config.cross_attention_dim == 1024 1091 | #print("checking model version: model is " + ('v2' if v2_model else 'v1')) 1092 | else: 1093 | v2_model = args.v1 1094 | 1095 | # convert and save 1096 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 1097 | 1098 | 1099 | if is_save_ckpt: 1100 | original_model = args.model_to_load if is_load_ckpt else None 1101 | key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, 1102 | original_model, args.epoch, args.global_step, save_dtype, vae) 1103 | 1104 | else: 1105 | save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, vae) 1106 | 1107 | 1108 | 1109 | if __name__ == '__main__': 1110 | parser = argparse.ArgumentParser() 1111 | parser.add_argument("--v1", action='store_true', 1112 | help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む') 1113 | parser.add_argument("--v2", action='store_true', 1114 | help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む') 1115 | parser.add_argument("--fp16", action='store_true', 1116 | help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)') 1117 | parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)') 1118 | parser.add_argument("--float", action='store_true', 1119 | help='save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)') 1120 | 
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値') 1121 | parser.add_argument("--global_step", type=int, default=0, 1122 | help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値') 1123 | 1124 | parser.add_argument("model_to_load", type=str, default=None, 1125 | help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ") 1126 | parser.add_argument("model_to_save", type=str, default=None, 1127 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存") 1128 | 1129 | args = parser.parse_args() 1130 | convert(args) 1131 | -------------------------------------------------------------------------------- /Dreambooth/convertosd.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | # Written by jachiam 5 | 6 | import argparse 7 | import os.path as osp 8 | 9 | import torch 10 | 11 | 12 | # =================# 13 | # UNet Conversion # 14 | # =================# 15 | 16 | unet_conversion_map = [ 17 | # (stable-diffusion, HF Diffusers) 18 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 19 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 20 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 21 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 22 | ("input_blocks.0.0.weight", "conv_in.weight"), 23 | ("input_blocks.0.0.bias", "conv_in.bias"), 24 | ("out.0.weight", "conv_norm_out.weight"), 25 | ("out.0.bias", "conv_norm_out.bias"), 26 | ("out.2.weight", "conv_out.weight"), 27 | ("out.2.bias", "conv_out.bias"), 28 | ] 29 | 30 | unet_conversion_map_resnet = [ 31 | # (stable-diffusion, HF Diffusers) 32 | ("in_layers.0", "norm1"), 33 | ("in_layers.2", "conv1"), 34 | ("out_layers.0", "norm2"), 35 | ("out_layers.3", "conv2"), 36 | ("emb_layers.1", "time_emb_proj"), 37 | ("skip_connection", "conv_shortcut"), 38 | ] 39 | 40 | unet_conversion_map_layer = [] 41 | # hardcoded number of downblocks and resnets/attentions... 42 | # would need smarter logic for other networks. 43 | for i in range(4): 44 | # loop over downblocks/upblocks 45 | 46 | for j in range(2): 47 | # loop over resnets/attentions for downblocks 48 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 49 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 50 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 51 | 52 | if i < 3: 53 | # no attention layers in down_blocks.3 54 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 55 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 56 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 57 | 58 | for j in range(3): 59 | # loop over resnets/attentions for upblocks 60 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 61 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 62 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 63 | 64 | if i > 0: 65 | # no attention layers in up_blocks.0 66 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 67 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 
68 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 69 | 70 | if i < 3: 71 | # no downsample in down_blocks.3 72 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 73 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 74 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 75 | 76 | # no upsample in up_blocks.3 77 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 78 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 79 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 80 | 81 | hf_mid_atn_prefix = "mid_block.attentions.0." 82 | sd_mid_atn_prefix = "middle_block.1." 83 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 84 | 85 | for j in range(2): 86 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 87 | sd_mid_res_prefix = f"middle_block.{2*j}." 88 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 89 | 90 | 91 | def convert_unet_state_dict(unet_state_dict): 92 | # buyer beware: this is a *brittle* function, 93 | # and correct output requires that all of these pieces interact in 94 | # the exact order in which I have arranged them. 95 | mapping = {k: k for k in unet_state_dict.keys()} 96 | for sd_name, hf_name in unet_conversion_map: 97 | mapping[hf_name] = sd_name 98 | for k, v in mapping.items(): 99 | if "resnets" in k: 100 | for sd_part, hf_part in unet_conversion_map_resnet: 101 | v = v.replace(hf_part, sd_part) 102 | mapping[k] = v 103 | for k, v in mapping.items(): 104 | for sd_part, hf_part in unet_conversion_map_layer: 105 | v = v.replace(hf_part, sd_part) 106 | mapping[k] = v 107 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 108 | return new_state_dict 109 | 110 | 111 | # ================# 112 | # VAE Conversion # 113 | # ================# 114 | 115 | vae_conversion_map = [ 116 | # (stable-diffusion, HF Diffusers) 117 | ("nin_shortcut", "conv_shortcut"), 118 | ("norm_out", "conv_norm_out"), 119 | ("mid.attn_1.", "mid_block.attentions.0."), 120 | ] 121 | 122 | for i in range(4): 123 | # down_blocks have two resnets 124 | for j in range(2): 125 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 126 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 127 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 128 | 129 | if i < 3: 130 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 131 | sd_downsample_prefix = f"down.{i}.downsample." 132 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 133 | 134 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 135 | sd_upsample_prefix = f"up.{3-i}.upsample." 136 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 137 | 138 | # up_blocks have three resnets 139 | # also, up blocks in hf are numbered in reverse from sd 140 | for j in range(3): 141 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 142 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 143 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 144 | 145 | # this part accounts for mid blocks in both the encoder and the decoder 146 | for i in range(2): 147 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 148 | sd_mid_res_prefix = f"mid.block_{i+1}." 
149 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 150 | 151 | 152 | vae_conversion_map_attn = [ 153 | # (stable-diffusion, HF Diffusers) 154 | ("norm.", "group_norm."), 155 | ("q.", "query."), 156 | ("k.", "key."), 157 | ("v.", "value."), 158 | ("proj_out.", "proj_attn."), 159 | ] 160 | 161 | 162 | def reshape_weight_for_sd(w): 163 | # convert HF linear weights to SD conv2d weights 164 | return w.reshape(*w.shape, 1, 1) 165 | 166 | 167 | def convert_vae_state_dict(vae_state_dict): 168 | mapping = {k: k for k in vae_state_dict.keys()} 169 | for k, v in mapping.items(): 170 | for sd_part, hf_part in vae_conversion_map: 171 | v = v.replace(hf_part, sd_part) 172 | mapping[k] = v 173 | for k, v in mapping.items(): 174 | if "attentions" in k: 175 | for sd_part, hf_part in vae_conversion_map_attn: 176 | v = v.replace(hf_part, sd_part) 177 | mapping[k] = v 178 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 179 | weights_to_convert = ["q", "k", "v", "proj_out"] 180 | print("[1;32mConverting to CKPT ...") 181 | for k, v in new_state_dict.items(): 182 | for weight_name in weights_to_convert: 183 | if f"mid.attn_1.{weight_name}.weight" in k: 184 | new_state_dict[k] = reshape_weight_for_sd(v) 185 | return new_state_dict 186 | 187 | 188 | # =========================# 189 | # Text Encoder Conversion # 190 | # =========================# 191 | # pretty much a no-op 192 | 193 | 194 | def convert_text_enc_state_dict(text_enc_dict): 195 | return text_enc_dict 196 | 197 | 198 | if __name__ == "__main__": 199 | 200 | 201 | model_path = "" 202 | checkpoint_path= "" 203 | 204 | unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") 205 | vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") 206 | text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") 207 | 208 | # Convert the UNet model 209 | unet_state_dict = torch.load(unet_path, map_location='cpu') 210 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 211 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 212 | 213 | # Convert the VAE model 214 | vae_state_dict = torch.load(vae_path, map_location='cpu') 215 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 216 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 217 | 218 | # Convert the text encoder model 219 | text_enc_dict = torch.load(text_enc_path, map_location='cpu') 220 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 221 | text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 222 | 223 | # Put together new checkpoint 224 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 225 | 226 | state_dict = {k:v.half() for k,v in state_dict.items()} 227 | state_dict = {"state_dict": state_dict} 228 | torch.save(state_dict, checkpoint_path) 229 | -------------------------------------------------------------------------------- /Dreambooth/convertosdv2.py: -------------------------------------------------------------------------------- 1 | print("\033[1;32mConverting to CKPT ...") 2 | import argparse 3 | import os 4 | import torch 5 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig 6 | from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel 7 | 8 | 9 | # Model parameters of the Diffusers version of Stable Diffusion 10 | NUM_TRAIN_TIMESTEPS = 1000 11 | BETA_START = 0.00085 12 | BETA_END = 0.0120 13 | 14 | UNET_PARAMS_MODEL_CHANNELS = 320 15 | UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] 16 | UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] 17 | UNET_PARAMS_IMAGE_SIZE = 64 18 | UNET_PARAMS_IN_CHANNELS = 4 19 | UNET_PARAMS_OUT_CHANNELS = 4 20 | UNET_PARAMS_NUM_RES_BLOCKS = 2 21 | UNET_PARAMS_CONTEXT_DIM = 768 22 | UNET_PARAMS_NUM_HEADS = 8 23 | 24 | VAE_PARAMS_Z_CHANNELS = 4 25 | VAE_PARAMS_RESOLUTION = 768 26 | VAE_PARAMS_IN_CHANNELS = 3 27 | VAE_PARAMS_OUT_CH = 3 28 | VAE_PARAMS_CH = 128 29 | VAE_PARAMS_CH_MULT = [1, 2, 4, 4] 30 | VAE_PARAMS_NUM_RES_BLOCKS = 2 31 | 32 | # V2 33 | V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] 34 | V2_UNET_PARAMS_CONTEXT_DIM = 1024 35 | 36 | 37 | # region StableDiffusion->Diffusers conversion code 38 | # copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0) 39 | 40 | 41 | def shave_segments(path, n_shave_prefix_segments=1): 42 | """ 43 | Removes segments. Positive values shave the first segments, negative shave the last segments.
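    Example: shave_segments("down.1.block.0.weight", 1) -> "1.block.0.weight";
    with n_shave_prefix_segments=-1 the last segment is dropped instead ("down.1.block.0").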
44 | """ 45 | if n_shave_prefix_segments >= 0: 46 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 47 | else: 48 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 49 | 50 | 51 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 52 | """ 53 | Updates paths inside resnets to the new naming scheme (local renaming) 54 | """ 55 | mapping = [] 56 | for old_item in old_list: 57 | new_item = old_item.replace("in_layers.0", "norm1") 58 | new_item = new_item.replace("in_layers.2", "conv1") 59 | 60 | new_item = new_item.replace("out_layers.0", "norm2") 61 | new_item = new_item.replace("out_layers.3", "conv2") 62 | 63 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 64 | new_item = new_item.replace("skip_connection", "conv_shortcut") 65 | 66 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 67 | 68 | mapping.append({"old": old_item, "new": new_item}) 69 | 70 | return mapping 71 | 72 | 73 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 74 | """ 75 | Updates paths inside resnets to the new naming scheme (local renaming) 76 | """ 77 | mapping = [] 78 | for old_item in old_list: 79 | new_item = old_item 80 | 81 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 82 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 83 | 84 | mapping.append({"old": old_item, "new": new_item}) 85 | 86 | return mapping 87 | 88 | 89 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 90 | """ 91 | Updates paths inside attentions to the new naming scheme (local renaming) 92 | """ 93 | mapping = [] 94 | for old_item in old_list: 95 | new_item = old_item 96 | 97 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 98 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 99 | 100 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 101 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 102 | 103 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 104 | 105 | mapping.append({"old": old_item, "new": new_item}) 106 | 107 | return mapping 108 | 109 | 110 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 111 | """ 112 | Updates paths inside attentions to the new naming scheme (local renaming) 113 | """ 114 | mapping = [] 115 | for old_item in old_list: 116 | new_item = old_item 117 | 118 | new_item = new_item.replace("norm.weight", "group_norm.weight") 119 | new_item = new_item.replace("norm.bias", "group_norm.bias") 120 | 121 | new_item = new_item.replace("q.weight", "query.weight") 122 | new_item = new_item.replace("q.bias", "query.bias") 123 | 124 | new_item = new_item.replace("k.weight", "key.weight") 125 | new_item = new_item.replace("k.bias", "key.bias") 126 | 127 | new_item = new_item.replace("v.weight", "value.weight") 128 | new_item = new_item.replace("v.bias", "value.bias") 129 | 130 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 131 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 132 | 133 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 134 | 135 | mapping.append({"old": old_item, "new": new_item}) 136 | 137 | return mapping 138 | 139 | 140 | def assign_to_checkpoint( 141 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 142 | ): 143 | """ 144 | This does the final conversion step: take locally converted weights 
and apply a global renaming 145 | to them. It splits attention layers, and takes into account additional replacements 146 | that may arise. 147 | 148 | Assigns the weights to the new checkpoint. 149 | """ 150 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 151 | 152 | # Splits the attention layers into three variables. 153 | if attention_paths_to_split is not None: 154 | for path, path_map in attention_paths_to_split.items(): 155 | old_tensor = old_checkpoint[path] 156 | channels = old_tensor.shape[0] // 3 157 | 158 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 159 | 160 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 161 | 162 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 163 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 164 | 165 | checkpoint[path_map["query"]] = query.reshape(target_shape) 166 | checkpoint[path_map["key"]] = key.reshape(target_shape) 167 | checkpoint[path_map["value"]] = value.reshape(target_shape) 168 | 169 | for path in paths: 170 | new_path = path["new"] 171 | 172 | # These have already been assigned 173 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 174 | continue 175 | 176 | # Global renaming happens here 177 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 178 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 179 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 180 | 181 | if additional_replacements is not None: 182 | for replacement in additional_replacements: 183 | new_path = new_path.replace(replacement["old"], replacement["new"]) 184 | 185 | # proj_attn.weight has to be converted from conv 1D to linear 186 | if "proj_attn.weight" in new_path: 187 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 188 | else: 189 | checkpoint[new_path] = old_checkpoint[path["old"]] 190 | 191 | 192 | def conv_attn_to_linear(checkpoint): 193 | keys = list(checkpoint.keys()) 194 | attn_keys = ["query.weight", "key.weight", "value.weight"] 195 | for key in keys: 196 | if ".".join(key.split(".")[-2:]) in attn_keys: 197 | if checkpoint[key].ndim > 2: 198 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 199 | elif "proj_attn.weight" in key: 200 | if checkpoint[key].ndim > 2: 201 | checkpoint[key] = checkpoint[key][:, :, 0] 202 | 203 | 204 | def linear_transformer_to_conv(checkpoint): 205 | keys = list(checkpoint.keys()) 206 | tf_keys = ["proj_in.weight", "proj_out.weight"] 207 | for key in keys: 208 | if ".".join(key.split(".")[-2:]) in tf_keys: 209 | if checkpoint[key].ndim == 2: 210 | checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) 211 | 212 | 213 | def convert_ldm_unet_checkpoint(v2, checkpoint, config): 214 | """ 215 | Takes a state dict and a config, and returns a converted checkpoint. 216 | """ 217 | 218 | # extract state_dict for UNet 219 | unet_state_dict = {} 220 | unet_key = "model.diffusion_model." 
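# strip the LDM prefix so the remaining keys use the raw UNet naming,
# e.g. "model.diffusion_model.input_blocks.0.0.weight" -> "input_blocks.0.0.weight"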
221 | keys = list(checkpoint.keys()) 222 | for key in keys: 223 | if key.startswith(unet_key): 224 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 225 | 226 | new_checkpoint = {} 227 | 228 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 229 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 230 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 231 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 232 | 233 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 234 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 235 | 236 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 237 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 238 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 239 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 240 | 241 | # Retrieves the keys for the input blocks only 242 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 243 | input_blocks = { 244 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 245 | for layer_id in range(num_input_blocks) 246 | } 247 | 248 | # Retrieves the keys for the middle blocks only 249 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 250 | middle_blocks = { 251 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 252 | for layer_id in range(num_middle_blocks) 253 | } 254 | 255 | # Retrieves the keys for the output blocks only 256 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 257 | output_blocks = { 258 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 259 | for layer_id in range(num_output_blocks) 260 | } 261 | 262 | for i in range(1, num_input_blocks): 263 | block_id = (i - 1) // (config["layers_per_block"] + 1) 264 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 265 | 266 | resnets = [ 267 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 268 | ] 269 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 270 | 271 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 272 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 273 | f"input_blocks.{i}.0.op.weight" 274 | ) 275 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 276 | f"input_blocks.{i}.0.op.bias" 277 | ) 278 | 279 | paths = renew_resnet_paths(resnets) 280 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 281 | assign_to_checkpoint( 282 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 283 | ) 284 | 285 | if len(attentions): 286 | paths = renew_attention_paths(attentions) 287 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 288 | assign_to_checkpoint( 289 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 290 | ) 291 | 292 | resnet_0 = middle_blocks[0] 293 | attentions = middle_blocks[1] 294 | resnet_1 = 
middle_blocks[2] 295 | 296 | resnet_0_paths = renew_resnet_paths(resnet_0) 297 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 298 | 299 | resnet_1_paths = renew_resnet_paths(resnet_1) 300 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 301 | 302 | attentions_paths = renew_attention_paths(attentions) 303 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 304 | assign_to_checkpoint( 305 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 306 | ) 307 | 308 | for i in range(num_output_blocks): 309 | block_id = i // (config["layers_per_block"] + 1) 310 | layer_in_block_id = i % (config["layers_per_block"] + 1) 311 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 312 | output_block_list = {} 313 | 314 | for layer in output_block_layers: 315 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 316 | if layer_id in output_block_list: 317 | output_block_list[layer_id].append(layer_name) 318 | else: 319 | output_block_list[layer_id] = [layer_name] 320 | 321 | if len(output_block_list) > 1: 322 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 323 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 324 | 325 | resnet_0_paths = renew_resnet_paths(resnets) 326 | paths = renew_resnet_paths(resnets) 327 | 328 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 329 | assign_to_checkpoint( 330 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 331 | ) 332 | 333 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 334 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 335 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 336 | f"output_blocks.{i}.{index}.conv.weight" 337 | ] 338 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 339 | f"output_blocks.{i}.{index}.conv.bias" 340 | ] 341 | 342 | # Clear attentions as they have been attributed above. 343 | if len(attentions) == 2: 344 | attentions = [] 345 | 346 | if len(attentions): 347 | paths = renew_attention_paths(attentions) 348 | meta_path = { 349 | "old": f"output_blocks.{i}.1", 350 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 351 | } 352 | assign_to_checkpoint( 353 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 354 | ) 355 | else: 356 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 357 | for path in resnet_0_paths: 358 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 359 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 360 | 361 | new_checkpoint[new_path] = unet_state_dict[old_path] 362 | 363 | # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する 364 | if v2: 365 | linear_transformer_to_conv(new_checkpoint) 366 | 367 | return new_checkpoint 368 | 369 | 370 | def convert_ldm_vae_checkpoint(checkpoint, config): 371 | # extract state dict for VAE 372 | vae_state_dict = {} 373 | vae_key = "first_stage_model." 
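# same idea as the UNet above: drop the "first_stage_model." prefix so the VAE keys
# match the plain encoder/decoder naming that the renaming tables below expect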
374 | keys = list(checkpoint.keys()) 375 | for key in keys: 376 | if key.startswith(vae_key): 377 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 378 | # if len(vae_state_dict) == 0: 379 | # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict 380 | # vae_state_dict = checkpoint 381 | 382 | new_checkpoint = {} 383 | 384 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 385 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 386 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 387 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 388 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 389 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 390 | 391 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 392 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 393 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 394 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 395 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 396 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 397 | 398 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 399 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 400 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 401 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 402 | 403 | # Retrieves the keys for the encoder down blocks only 404 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 405 | down_blocks = { 406 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 407 | } 408 | 409 | # Retrieves the keys for the decoder up blocks only 410 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 411 | up_blocks = { 412 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 413 | } 414 | 415 | for i in range(num_down_blocks): 416 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 417 | 418 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 419 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 420 | f"encoder.down.{i}.downsample.conv.weight" 421 | ) 422 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 423 | f"encoder.down.{i}.downsample.conv.bias" 424 | ) 425 | 426 | paths = renew_vae_resnet_paths(resnets) 427 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 428 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 429 | 430 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 431 | num_mid_res_blocks = 2 432 | for i in range(1, num_mid_res_blocks + 1): 433 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 434 | 435 | paths = renew_vae_resnet_paths(resnets) 436 | meta_path = {"old": f"mid.block_{i}", 
"new": f"mid_block.resnets.{i - 1}"} 437 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 438 | 439 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 440 | paths = renew_vae_attention_paths(mid_attentions) 441 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 442 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 443 | conv_attn_to_linear(new_checkpoint) 444 | 445 | for i in range(num_up_blocks): 446 | block_id = num_up_blocks - 1 - i 447 | resnets = [ 448 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 449 | ] 450 | 451 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 452 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 453 | f"decoder.up.{block_id}.upsample.conv.weight" 454 | ] 455 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 456 | f"decoder.up.{block_id}.upsample.conv.bias" 457 | ] 458 | 459 | paths = renew_vae_resnet_paths(resnets) 460 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 461 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 462 | 463 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 464 | num_mid_res_blocks = 2 465 | for i in range(1, num_mid_res_blocks + 1): 466 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 467 | 468 | paths = renew_vae_resnet_paths(resnets) 469 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 470 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 471 | 472 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 473 | paths = renew_vae_attention_paths(mid_attentions) 474 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 475 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 476 | conv_attn_to_linear(new_checkpoint) 477 | return new_checkpoint 478 | 479 | 480 | def create_unet_diffusers_config(v2): 481 | """ 482 | Creates a config for the diffusers based on the config of the LDM model. 
483 | """ 484 | # unet_params = original_config.model.params.unet_config.params 485 | 486 | block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] 487 | 488 | down_block_types = [] 489 | resolution = 1 490 | for i in range(len(block_out_channels)): 491 | block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" 492 | down_block_types.append(block_type) 493 | if i != len(block_out_channels) - 1: 494 | resolution *= 2 495 | 496 | up_block_types = [] 497 | for i in range(len(block_out_channels)): 498 | block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" 499 | up_block_types.append(block_type) 500 | resolution //= 2 501 | 502 | config = dict( 503 | sample_size=UNET_PARAMS_IMAGE_SIZE, 504 | in_channels=UNET_PARAMS_IN_CHANNELS, 505 | out_channels=UNET_PARAMS_OUT_CHANNELS, 506 | down_block_types=tuple(down_block_types), 507 | up_block_types=tuple(up_block_types), 508 | block_out_channels=tuple(block_out_channels), 509 | layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, 510 | cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, 511 | attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, 512 | ) 513 | 514 | return config 515 | 516 | 517 | def create_vae_diffusers_config(): 518 | """ 519 | Creates a config for the diffusers based on the config of the LDM model. 520 | """ 521 | # vae_params = original_config.model.params.first_stage_config.params.ddconfig 522 | # _ = original_config.model.params.first_stage_config.params.embed_dim 523 | block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] 524 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 525 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 526 | 527 | config = dict( 528 | sample_size=VAE_PARAMS_RESOLUTION, 529 | in_channels=VAE_PARAMS_IN_CHANNELS, 530 | out_channels=VAE_PARAMS_OUT_CH, 531 | down_block_types=tuple(down_block_types), 532 | up_block_types=tuple(up_block_types), 533 | block_out_channels=tuple(block_out_channels), 534 | latent_channels=VAE_PARAMS_Z_CHANNELS, 535 | layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, 536 | ) 537 | return config 538 | 539 | 540 | def convert_ldm_clip_checkpoint_v1(checkpoint): 541 | keys = list(checkpoint.keys()) 542 | text_model_dict = {} 543 | for key in keys: 544 | if key.startswith("cond_stage_model.transformer"): 545 | text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] 546 | return text_model_dict 547 | 548 | 549 | def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): 550 | # 嫌になるくらい違うぞ! 551 | def convert_key(key): 552 | if not key.startswith("cond_stage_model"): 553 | return None 554 | 555 | # common conversion 556 | key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") 557 | key = key.replace("cond_stage_model.model.", "text_model.") 558 | 559 | if "resblocks" in key: 560 | # resblocks conversion 561 | key = key.replace(".resblocks.", ".layers.") 562 | if ".ln_" in key: 563 | key = key.replace(".ln_", ".layer_norm") 564 | elif ".mlp." 
in key: 565 | key = key.replace(".c_fc.", ".fc1.") 566 | key = key.replace(".c_proj.", ".fc2.") 567 | elif '.attn.out_proj' in key: 568 | key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") 569 | elif '.attn.in_proj' in key: 570 | key = None # 特殊なので後で処理する 571 | else: 572 | raise ValueError(f"unexpected key in SD: {key}") 573 | elif '.positional_embedding' in key: 574 | key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") 575 | elif '.text_projection' in key: 576 | key = None # 使われない??? 577 | elif '.logit_scale' in key: 578 | key = None # 使われない??? 579 | elif '.token_embedding' in key: 580 | key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") 581 | elif '.ln_final' in key: 582 | key = key.replace(".ln_final", ".final_layer_norm") 583 | return key 584 | 585 | keys = list(checkpoint.keys()) 586 | new_sd = {} 587 | for key in keys: 588 | # remove resblocks 23 589 | if '.resblocks.23.' in key: 590 | continue 591 | new_key = convert_key(key) 592 | if new_key is None: 593 | continue 594 | new_sd[new_key] = checkpoint[key] 595 | 596 | # attnの変換 597 | for key in keys: 598 | if '.resblocks.23.' in key: 599 | continue 600 | if '.resblocks' in key and '.attn.in_proj_' in key: 601 | # 三つに分割 602 | values = torch.chunk(checkpoint[key], 3) 603 | 604 | key_suffix = ".weight" if "weight" in key else ".bias" 605 | key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") 606 | key_pfx = key_pfx.replace("_weight", "") 607 | key_pfx = key_pfx.replace("_bias", "") 608 | key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") 609 | new_sd[key_pfx + "q_proj" + key_suffix] = values[0] 610 | new_sd[key_pfx + "k_proj" + key_suffix] = values[1] 611 | new_sd[key_pfx + "v_proj" + key_suffix] = values[2] 612 | 613 | # position_idsの追加 614 | new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) 615 | return new_sd 616 | 617 | # endregion 618 | 619 | 620 | # region Diffusers->StableDiffusion の変換コード 621 | # convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0) 622 | 623 | def conv_transformer_to_linear(checkpoint): 624 | keys = list(checkpoint.keys()) 625 | tf_keys = ["proj_in.weight", "proj_out.weight"] 626 | for key in keys: 627 | if ".".join(key.split(".")[-2:]) in tf_keys: 628 | if checkpoint[key].ndim > 2: 629 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 630 | 631 | 632 | def convert_unet_state_dict_to_sd(v2, unet_state_dict): 633 | unet_conversion_map = [ 634 | # (stable-diffusion, HF Diffusers) 635 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 636 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 637 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 638 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 639 | ("input_blocks.0.0.weight", "conv_in.weight"), 640 | ("input_blocks.0.0.bias", "conv_in.bias"), 641 | ("out.0.weight", "conv_norm_out.weight"), 642 | ("out.0.bias", "conv_norm_out.bias"), 643 | ("out.2.weight", "conv_out.weight"), 644 | ("out.2.bias", "conv_out.bias"), 645 | ] 646 | 647 | unet_conversion_map_resnet = [ 648 | # (stable-diffusion, HF Diffusers) 649 | ("in_layers.0", "norm1"), 650 | ("in_layers.2", "conv1"), 651 | ("out_layers.0", "norm2"), 652 | ("out_layers.3", "conv2"), 653 | ("emb_layers.1", "time_emb_proj"), 654 | ("skip_connection", "conv_shortcut"), 655 | ] 656 | 657 | unet_conversion_map_layer = [] 658 | for i in range(4): 659 | # loop over downblocks/upblocks 660 | 661 | for j 
in range(2): 662 | # loop over resnets/attentions for downblocks 663 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 664 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 665 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 666 | 667 | if i < 3: 668 | # no attention layers in down_blocks.3 669 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 670 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 671 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 672 | 673 | for j in range(3): 674 | # loop over resnets/attentions for upblocks 675 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 676 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 677 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 678 | 679 | if i > 0: 680 | # no attention layers in up_blocks.0 681 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 682 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 683 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 684 | 685 | if i < 3: 686 | # no downsample in down_blocks.3 687 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 688 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 689 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 690 | 691 | # no upsample in up_blocks.3 692 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 693 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 694 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 695 | 696 | hf_mid_atn_prefix = "mid_block.attentions.0." 697 | sd_mid_atn_prefix = "middle_block.1." 698 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 699 | 700 | for j in range(2): 701 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 702 | sd_mid_res_prefix = f"middle_block.{2*j}." 703 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 704 | 705 | # buyer beware: this is a *brittle* function, 706 | # and correct output requires that all of these pieces interact in 707 | # the exact order in which I have arranged them. 708 | mapping = {k: k for k in unet_state_dict.keys()} 709 | for sd_name, hf_name in unet_conversion_map: 710 | mapping[hf_name] = sd_name 711 | for k, v in mapping.items(): 712 | if "resnets" in k: 713 | for sd_part, hf_part in unet_conversion_map_resnet: 714 | v = v.replace(hf_part, sd_part) 715 | mapping[k] = v 716 | for k, v in mapping.items(): 717 | for sd_part, hf_part in unet_conversion_map_layer: 718 | v = v.replace(hf_part, sd_part) 719 | mapping[k] = v 720 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 721 | 722 | if v2: 723 | conv_transformer_to_linear(new_state_dict) 724 | 725 | return new_state_dict 726 | 727 | 728 | # ================# 729 | # VAE Conversion # 730 | # ================# 731 | 732 | def reshape_weight_for_sd(w): 733 | # convert HF linear weights to SD conv2d weights 734 | return w.reshape(*w.shape, 1, 1) 735 | 736 | 737 | def convert_vae_state_dict(vae_state_dict): 738 | vae_conversion_map = [ 739 | # (stable-diffusion, HF Diffusers) 740 | ("nin_shortcut", "conv_shortcut"), 741 | ("norm_out", "conv_norm_out"), 742 | ("mid.attn_1.", "mid_block.attentions.0."), 743 | ] 744 | 745 | for i in range(4): 746 | # down_blocks have two resnets 747 | for j in range(2): 748 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 749 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 
750 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 751 | 752 | if i < 3: 753 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 754 | sd_downsample_prefix = f"down.{i}.downsample." 755 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 756 | 757 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 758 | sd_upsample_prefix = f"up.{3-i}.upsample." 759 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 760 | 761 | # up_blocks have three resnets 762 | # also, up blocks in hf are numbered in reverse from sd 763 | for j in range(3): 764 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 765 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 766 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 767 | 768 | # this part accounts for mid blocks in both the encoder and the decoder 769 | for i in range(2): 770 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 771 | sd_mid_res_prefix = f"mid.block_{i+1}." 772 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 773 | 774 | vae_conversion_map_attn = [ 775 | # (stable-diffusion, HF Diffusers) 776 | ("norm.", "group_norm."), 777 | ("q.", "query."), 778 | ("k.", "key."), 779 | ("v.", "value."), 780 | ("proj_out.", "proj_attn."), 781 | ] 782 | 783 | mapping = {k: k for k in vae_state_dict.keys()} 784 | for k, v in mapping.items(): 785 | for sd_part, hf_part in vae_conversion_map: 786 | v = v.replace(hf_part, sd_part) 787 | mapping[k] = v 788 | for k, v in mapping.items(): 789 | if "attentions" in k: 790 | for sd_part, hf_part in vae_conversion_map_attn: 791 | v = v.replace(hf_part, sd_part) 792 | mapping[k] = v 793 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 794 | weights_to_convert = ["q", "k", "v", "proj_out"] 795 | 796 | for k, v in new_state_dict.items(): 797 | for weight_name in weights_to_convert: 798 | if f"mid.attn_1.{weight_name}.weight" in k: 799 | new_state_dict[k] = reshape_weight_for_sd(v) 800 | 801 | return new_state_dict 802 | 803 | 804 | # endregion 805 | 806 | 807 | def load_checkpoint_with_text_encoder_conversion(ckpt_path): 808 | # handle models that store the text encoder keys differently (without 'text_model') 809 | TEXT_ENCODER_KEY_REPLACEMENTS = [ 810 | ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), 811 | ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), 812 | ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') 813 | ] 814 | 815 | checkpoint = torch.load(ckpt_path, map_location="cpu") 816 | state_dict = checkpoint["state_dict"] 817 | key_reps = [] 818 | for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: 819 | for key in state_dict.keys(): 820 | if key.startswith(rep_from): 821 | new_key = rep_to + key[len(rep_from):] 822 | key_reps.append((key, new_key)) 823 | 824 | for key, new_key in key_reps: 825 | state_dict[new_key] = state_dict[key] 826 | del state_dict[key] 827 | 828 | return checkpoint 829 | 830 | 831 | # TODO the dtype option behaves oddly and needs checking; building the text_encoder in the requested dtype is still unverified 832 | def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): 833 | 834 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 835 | state_dict = checkpoint["state_dict"] 836 | if dtype is not None: 837 | for k, v in state_dict.items(): 838 | if type(v) is torch.Tensor: 839 | state_dict[k] = v.to(dtype) 840 | 841 | # Convert the UNet2DConditionModel model.
842 | unet_config = create_unet_diffusers_config(v2) 843 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config) 844 | 845 | unet = UNet2DConditionModel(**unet_config) 846 | info = unet.load_state_dict(converted_unet_checkpoint) 847 | 848 | 849 | # Convert the VAE model. 850 | vae_config = create_vae_diffusers_config() 851 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) 852 | 853 | vae = AutoencoderKL(**vae_config) 854 | info = vae.load_state_dict(converted_vae_checkpoint) 855 | 856 | 857 | # convert text_model 858 | if v2: 859 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77) 860 | cfg = CLIPTextConfig( 861 | vocab_size=49408, 862 | hidden_size=1024, 863 | intermediate_size=4096, 864 | num_hidden_layers=23, 865 | num_attention_heads=16, 866 | max_position_embeddings=77, 867 | hidden_act="gelu", 868 | layer_norm_eps=1e-05, 869 | dropout=0.0, 870 | attention_dropout=0.0, 871 | initializer_range=0.02, 872 | initializer_factor=1.0, 873 | pad_token_id=1, 874 | bos_token_id=0, 875 | eos_token_id=2, 876 | model_type="clip_text_model", 877 | projection_dim=512, 878 | torch_dtype="float32", 879 | transformers_version="4.25.0.dev0", 880 | ) 881 | text_model = CLIPTextModel._from_config(cfg) 882 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 883 | else: 884 | converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict) 885 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 886 | info = text_model.load_state_dict(converted_text_encoder_checkpoint) 887 | 888 | 889 | return text_model, vae, unet 890 | 891 | 892 | def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): 893 | def convert_key(key): 894 | # remove position_ids 895 | if ".position_ids" in key: 896 | return None 897 | 898 | # common 899 | key = key.replace("text_model.encoder.", "transformer.") 900 | key = key.replace("text_model.", "") 901 | if "layers" in key: 902 | # resblocks conversion 903 | key = key.replace(".layers.", ".resblocks.") 904 | if ".layer_norm" in key: 905 | key = key.replace(".layer_norm", ".ln_") 906 | elif ".mlp." in key: 907 | key = key.replace(".fc1.", ".c_fc.") 908 | key = key.replace(".fc2.", ".c_proj.") 909 | elif '.self_attn.out_proj' in key: 910 | key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") 911 | elif '.self_attn.'
in key: 912 | key = None # special case, handled later 913 | else: 914 | raise ValueError(f"unexpected key in Diffusers model: {key}") 915 | elif '.position_embedding' in key: 916 | key = key.replace("embeddings.position_embedding.weight", "positional_embedding") 917 | elif '.token_embedding' in key: 918 | key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") 919 | elif 'final_layer_norm' in key: 920 | key = key.replace("final_layer_norm", "ln_final") 921 | return key 922 | 923 | keys = list(checkpoint.keys()) 924 | new_sd = {} 925 | for key in keys: 926 | new_key = convert_key(key) 927 | if new_key is None: 928 | continue 929 | new_sd[new_key] = checkpoint[key] 930 | 931 | # convert attn 932 | for key in keys: 933 | if 'layers' in key and 'q_proj' in key: 934 | # concatenate the three projections 935 | key_q = key 936 | key_k = key.replace("q_proj", "k_proj") 937 | key_v = key.replace("q_proj", "v_proj") 938 | 939 | value_q = checkpoint[key_q] 940 | value_k = checkpoint[key_k] 941 | value_v = checkpoint[key_v] 942 | value = torch.cat([value_q, value_k, value_v]) 943 | 944 | new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") 945 | new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") 946 | new_sd[new_key] = value 947 | 948 | # whether to fabricate the last layer, etc. 949 | if make_dummy_weights: 950 | 951 | keys = list(new_sd.keys()) 952 | for key in keys: 953 | if key.startswith("transformer.resblocks.22."): 954 | new_sd[key.replace(".22.", ".23.")] = new_sd[key] 955 | 956 | # create weights that are not included in Diffusers models 957 | new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) 958 | new_sd['logit_scale'] = torch.tensor(1) 959 | 960 | return new_sd 961 | 962 | 963 | def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): 964 | if ckpt_path is not None: 965 | # read epoch/step from it; also reload it (including the VAE), e.g. when the VAE is not in memory 966 | checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path) 967 | state_dict = checkpoint["state_dict"] 968 | strict = True 969 | else: 970 | # build a new checkpoint from scratch 971 | checkpoint = {} 972 | state_dict = {} 973 | strict = False 974 | 975 | def update_sd(prefix, sd): 976 | for k, v in sd.items(): 977 | key = prefix + k 978 | assert not strict or key in state_dict, f"Illegal key in save SD: {key}" 979 | if save_dtype is not None: 980 | v = v.detach().clone().to("cpu").to(save_dtype) 981 | state_dict[key] = v 982 | 983 | # Convert the UNet model 984 | unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) 985 | update_sd("model.diffusion_model.", unet_state_dict) 986 | 987 | # Convert the text encoder model 988 | if v2: 989 | make_dummy = ckpt_path is None # when there is no reference checkpoint, fill in dummy weights, e.g. by duplicating the last layer from the previous one 990 | text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) 991 | update_sd("cond_stage_model.model.", text_enc_dict) 992 | else: 993 | text_enc_dict = text_encoder.state_dict() 994 | update_sd("cond_stage_model.transformer.", text_enc_dict) 995 | 996 | # Convert the VAE 997 | if vae is not None: 998 | vae_dict = convert_vae_state_dict(vae.state_dict()) 999 | update_sd("first_stage_model.", vae_dict) 1000 | 1001 | # Put together new checkpoint 1002 | key_count = len(state_dict.keys()) 1003 | new_ckpt = {'state_dict': state_dict} 1004 | 1005 | if 'epoch' in checkpoint: 1006 | epochs += checkpoint['epoch'] 1007 | if 'global_step' in checkpoint: 1008 | steps += checkpoint['global_step'] 1009 | 1010 | new_ckpt['epoch'] = epochs
1011 | new_ckpt['global_step'] = steps 1012 | 1013 | torch.save(new_ckpt, output_file) 1014 | 1015 | return key_count 1016 | 1017 | 1018 | def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None): 1019 | if vae is None: 1020 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 1021 | pipeline = StableDiffusionPipeline( 1022 | unet=unet, 1023 | text_encoder=text_encoder, 1024 | vae=vae, 1025 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"), 1026 | tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"), 1027 | ) 1028 | pipeline.save_pretrained(output_dir) 1029 | 1030 | 1031 | 1032 | def convert(args): 1033 | # check the arguments 1034 | load_dtype = torch.float16 if args.fp16 else None 1035 | 1036 | save_dtype = None 1037 | if args.fp16: 1038 | save_dtype = torch.float16 1039 | elif args.bf16: 1040 | save_dtype = torch.bfloat16 1041 | elif args.float: 1042 | save_dtype = torch.float 1043 | 1044 | is_load_ckpt = os.path.isfile(args.model_to_load) 1045 | is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0 1046 | 1047 | assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint" 1048 | assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers" 1049 | 1050 | # load the model 1051 | msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else "")) 1052 | 1053 | 1054 | if is_load_ckpt: 1055 | v2_model = args.v2 1056 | text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load) 1057 | else: 1058 | pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None) 1059 | text_encoder = pipe.text_encoder 1060 | vae = pipe.vae 1061 | unet = pipe.unet 1062 | 1063 | if args.v1 == args.v2: 1064 | # auto-detect the model version 1065 | v2_model = unet.config.cross_attention_dim == 1024 1066 | #print("checking model version: model is " + ('v2' if v2_model else 'v1')) 1067 | else: 1068 | v2_model = args.v2 1069 | 1070 | # convert and save 1071 | msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers" 1072 | 1073 | 1074 | if is_save_ckpt: 1075 | original_model = args.model_to_load if is_load_ckpt else None 1076 | key_count = save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet, 1077 | original_model, args.epoch, args.global_step, save_dtype, vae) 1078 | 1079 | else: 1080 | save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae) 1081 | 1082 | 1083 | 1084 | if __name__ == '__main__': 1085 | parser = argparse.ArgumentParser() 1086 | parser.add_argument("--v1", action='store_true', 1087 | help='load v1.x model (v1 or v2 is required to load checkpoint)') 1088 | parser.add_argument("--v2", action='store_true', 1089 | help='load v2.0 model (v1 or v2 is required to load checkpoint)') 1090 | parser.add_argument("--fp16", action='store_true', 1091 | help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only)') 1092 | parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only)') 1093 | parser.add_argument("--float", action='store_true', 1094 |
help='save as float (float32) (checkpoint only)') 1095 | parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint') 1096 | parser.add_argument("--global_step", type=int, default=0, 1097 | help='global_step to write to checkpoint') 1098 | parser.add_argument("--reference_model", type=str, default=None, 1099 | help="reference model for scheduler/tokenizer, required when saving as Diffusers; the scheduler/tokenizer are copied from this model") 1100 | 1101 | parser.add_argument("model_to_load", type=str, default=None, 1102 | help="model to load: checkpoint file or Diffusers model's directory") 1103 | parser.add_argument("model_to_save", type=str, default=None, 1104 | help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension)") 1105 | 1106 | args = parser.parse_args() 1107 | convert(args) 1108 | -------------------------------------------------------------------------------- /Dreambooth/det.py: -------------------------------------------------------------------------------- 1 | #Adapted from A1111 2 | import argparse 3 | import torch 4 | import open_clip 5 | import transformers.utils.hub 6 | from safetensors import safe_open 7 | import os 8 | import sys 9 | import wget 10 | from subprocess import call 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--MODEL_PATH", type=str) 14 | parser.add_argument("--from_safetensors", action='store_true') 15 | args = parser.parse_args() 16 | 17 | wget.download("https://github.com/TheLastBen/fast-stable-diffusion/raw/main/Dreambooth/ldm.zip") 18 | call('unzip ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 19 | call('rm ldm.zip', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 20 | 21 | import ldm.modules.diffusionmodules.openaimodel 22 | import ldm.modules.encoders.modules 23 | 24 | class DisableInitialization: 25 | 26 | def __init__(self, disable_clip=True): 27 | self.replaced = [] 28 | self.disable_clip = disable_clip 29 | 30 | def replace(self, obj, field, func): 31 | original = getattr(obj, field, None) 32 | if original is None: 33 | return None 34 | 35 | self.replaced.append((obj, field, original)) 36 | setattr(obj, field, func) 37 | 38 | return original 39 | 40 | def __enter__(self): 41 | def do_nothing(*args, **kwargs): 42 | pass 43 | 44 | def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): 45 | return self.create_model_and_transforms(*args, pretrained=None, **kwargs) 46 | 47 | def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): 48 | res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) 49 | res.name_or_path = pretrained_model_name_or_path 50 | return res 51 | 52 | def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): 53 | args = args[0:3] + ('/', ) + args[4:] 54 | return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) 55 | 56 | def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): 57 | 58 | if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': 59 | return None 60 |
61 | try: 62 | res = original(url, *args, local_files_only=True, **kwargs) 63 | if res is None: 64 | res = original(url, *args, local_files_only=False, **kwargs) 65 | return res 66 | except Exception as e: 67 | return original(url, *args, local_files_only=False, **kwargs) 68 | 69 | def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): 70 | return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) 71 | 72 | def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): 73 | return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) 74 | 75 | def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): 76 | return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) 77 | 78 | self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) 79 | self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) 80 | self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) 81 | 82 | if self.disable_clip: 83 | self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) 84 | self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) 85 | self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) 86 | self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) 87 | self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) 88 | self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) 89 | 90 | def __exit__(self, exc_type, exc_val, exc_tb): 91 | for obj, field, original in self.replaced: 92 | setattr(obj, field, original) 93 | 94 | self.replaced.clear() 95 | 96 | 97 | def vpar(state_dict): 98 | 99 | device = torch.device("cuda") 100 | 101 | with DisableInitialization(): 102 | unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( 103 | use_checkpoint=True, 104 | use_fp16=False, 105 | image_size=32, 106 | in_channels=4, 107 | out_channels=4, 108 | model_channels=320, 109 | attention_resolutions=[4, 2, 1], 110 | num_res_blocks=2, 111 | channel_mult=[1, 2, 4, 4], 112 | num_head_channels=64, 113 | use_spatial_transformer=True, 114 | use_linear_in_transformer=True, 115 | transformer_depth=1, 116 | context_dim=1024, 117 | legacy=False 118 | ) 119 | unet.eval() 120 | 121 | with torch.no_grad(): 122 | unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} 123 | unet.load_state_dict(unet_sd, strict=True) 124 | unet.to(device=device, dtype=torch.float) 125 | 126 | test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 127 | x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 128 | 129 | out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item() 130 | 131 | return out < -1 132 | 133 | 134 | def detect_version(sd): 135 | 136 | sys.stdout = open(os.devnull, 'w') 137 | 138 | sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) 139 | diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) 140 | 141 | if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: 142 | 143 | if vpar(sd): 144 | sys.stdout = sys.__stdout__ 145 | sd2_v=print("V2.1-768px") 146 | return sd2_v 147 | else: 148 | sys.stdout = sys.__stdout__ 149 | sd2=print("V2.1-512px") 150 | return sd2 151 | 152 | else: 153 | sys.stdout = sys.__stdout__ 154 | v1=print("1.5") 155 | return v1 156 | 157 | 158 | if args.from_safetensors: 159 | 160 | checkpoint = {} 161 | with safe_open(args.MODEL_PATH, framework="pt", device="cuda") as f: 162 | for key in f.keys(): 163 | checkpoint[key] = f.get_tensor(key) 164 | state_dict = checkpoint 165 | else: 166 | checkpoint = torch.load(args.MODEL_PATH, map_location="cuda") 167 | state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint 168 | 169 | detect_version(state_dict) 170 | 171 | call('rm -r ldm', shell=True, stdout=open('/dev/null', 'w'), stderr=open('/dev/null', 'w')) 172 | -------------------------------------------------------------------------------- /Dreambooth/ldm.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/ldm.zip -------------------------------------------------------------------------------- /Dreambooth/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "scheduler": [ 5 | "diffusers", 6 | "PNDMScheduler" 7 | ], 8 | "text_encoder": [ 9 | "transformers", 10 | "CLIPTextModel" 11 | ], 12 | "tokenizer": [ 13 | "transformers", 14 | "CLIPTokenizer" 15 | ], 16 | "unet": [ 17 | "diffusers", 18 | "UNet2DConditionModel" 19 | ], 20 | "vae": [ 21 | "diffusers", 22 | "AutoencoderKL" 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /Dreambooth/refmdlz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Strabo80/fast-stable-diffusion/70247ba41f3235337138b23f52b3d02fbbc92aa9/Dreambooth/refmdlz -------------------------------------------------------------------------------- /Dreambooth/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null 12 | 13 | } 14 | -------------------------------------------------------------------------------- /Dreambooth/smart_crop.py: 
-------------------------------------------------------------------------------- 1 | #Based on A1111 cropping script 2 | import cv2 3 | import requests 4 | import os 5 | from collections import defaultdict 6 | from math import log, sqrt 7 | import numpy as np 8 | from PIL import Image, ImageDraw 9 | 10 | GREEN = "#0F0" 11 | BLUE = "#00F" 12 | RED = "#F00" 13 | 14 | 15 | def crop_image(im, size): 16 | 17 | def focal_point(im, settings): 18 | corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] 19 | entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else [] 20 | face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else [] 21 | 22 | pois = [] 23 | 24 | weight_pref_total = 0 25 | if len(corner_points) > 0: 26 | weight_pref_total += settings.corner_points_weight 27 | if len(entropy_points) > 0: 28 | weight_pref_total += settings.entropy_points_weight 29 | if len(face_points) > 0: 30 | weight_pref_total += settings.face_points_weight 31 | 32 | corner_centroid = None 33 | if len(corner_points) > 0: 34 | corner_centroid = centroid(corner_points) 35 | corner_centroid.weight = settings.corner_points_weight / weight_pref_total 36 | pois.append(corner_centroid) 37 | 38 | entropy_centroid = None 39 | if len(entropy_points) > 0: 40 | entropy_centroid = centroid(entropy_points) 41 | entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total 42 | pois.append(entropy_centroid) 43 | 44 | face_centroid = None 45 | if len(face_points) > 0: 46 | face_centroid = centroid(face_points) 47 | face_centroid.weight = settings.face_points_weight / weight_pref_total 48 | pois.append(face_centroid) 49 | 50 | average_point = poi_average(pois, settings) 51 | 52 | return average_point 53 | 54 | 55 | def image_face_points(im, settings): 56 | 57 | np_im = np.array(im) 58 | gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY) 59 | 60 | tries = [ 61 | [ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ], 62 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ], 63 | [ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ], 64 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ], 65 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ], 66 | [ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ], 67 | [ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ], 68 | [ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ] 69 | ] 70 | for t in tries: 71 | classifier = cv2.CascadeClassifier(t[0]) 72 | minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side 73 | try: 74 | faces = classifier.detectMultiScale(gray, scaleFactor=1.1, 75 | minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) 76 | except: 77 | continue 78 | 79 | if len(faces) > 0: 80 | rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] 81 | return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] 82 | return [] 83 | 84 | 85 | def image_corner_points(im, settings): 86 | grayscale = im.convert("L") 87 | 88 | # naive attempt at preventing focal points from collecting at watermarks near the bottom 89 | gd = ImageDraw.Draw(grayscale) 90 | gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999") 91 | 92 | np_im = np.array(grayscale) 93 | 94 | points = cv2.goodFeaturesToTrack( 95 | np_im, 96 | maxCorners=100, 97 | 
qualityLevel=0.04, 98 | minDistance=min(grayscale.width, grayscale.height)*0.06, 99 | useHarrisDetector=False, 100 | ) 101 | 102 | if points is None: 103 | return [] 104 | 105 | focal_points = [] 106 | for point in points: 107 | x, y = point.ravel() 108 | focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points))) 109 | 110 | return focal_points 111 | 112 | 113 | def image_entropy_points(im, settings): 114 | landscape = im.height < im.width 115 | portrait = im.height > im.width 116 | if landscape: 117 | move_idx = [0, 2] 118 | move_max = im.size[0] 119 | elif portrait: 120 | move_idx = [1, 3] 121 | move_max = im.size[1] 122 | else: 123 | return [] 124 | 125 | e_max = 0 126 | crop_current = [0, 0, settings.crop_width, settings.crop_height] 127 | crop_best = crop_current 128 | while crop_current[move_idx[1]] < move_max: 129 | crop = im.crop(tuple(crop_current)) 130 | e = image_entropy(crop) 131 | 132 | if (e > e_max): 133 | e_max = e 134 | crop_best = list(crop_current) 135 | 136 | crop_current[move_idx[0]] += 4 137 | crop_current[move_idx[1]] += 4 138 | 139 | x_mid = int(crop_best[0] + settings.crop_width/2) 140 | y_mid = int(crop_best[1] + settings.crop_height/2) 141 | 142 | return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)] 143 | 144 | 145 | def image_entropy(im): 146 | # greyscale image entropy 147 | # band = np.asarray(im.convert("L")) 148 | band = np.asarray(im.convert("1"), dtype=np.uint8) 149 | hist, _ = np.histogram(band, bins=range(0, 256)) 150 | hist = hist[hist > 0] 151 | return -np.log2(hist / hist.sum()).sum() 152 | 153 | def centroid(pois): 154 | x = [poi.x for poi in pois] 155 | y = [poi.y for poi in pois] 156 | return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) 157 | 158 | 159 | def poi_average(pois, settings): 160 | weight = 0.0 161 | x = 0.0 162 | y = 0.0 163 | for poi in pois: 164 | weight += poi.weight 165 | x += poi.x * poi.weight 166 | y += poi.y * poi.weight 167 | avg_x = round(weight and x / weight) 168 | avg_y = round(weight and y / weight) 169 | 170 | return PointOfInterest(avg_x, avg_y) 171 | 172 | 173 | def is_landscape(w, h): 174 | return w > h 175 | 176 | 177 | def is_portrait(w, h): 178 | return h > w 179 | 180 | 181 | def is_square(w, h): 182 | return w == h 183 | 184 | 185 | class PointOfInterest: 186 | def __init__(self, x, y, weight=1.0, size=10): 187 | self.x = x 188 | self.y = y 189 | self.weight = weight 190 | self.size = size 191 | 192 | def bounding(self, size): 193 | return [ 194 | self.x - size//2, 195 | self.y - size//2, 196 | self.x + size//2, 197 | self.y + size//2 198 | ] 199 | 200 | class Settings: 201 | def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5): 202 | self.crop_width = crop_width 203 | self.crop_height = crop_height 204 | self.corner_points_weight = corner_points_weight 205 | self.entropy_points_weight = entropy_points_weight 206 | self.face_points_weight = face_points_weight 207 | 208 | settings = Settings( 209 | crop_width = size, 210 | crop_height = size, 211 | face_points_weight = 0.9, 212 | entropy_points_weight = 0.15, 213 | corner_points_weight = 0.5, 214 | ) 215 | 216 | scale_by = 1 217 | if is_landscape(im.width, im.height): 218 | scale_by = settings.crop_height / im.height 219 | elif is_portrait(im.width, im.height): 220 | scale_by = settings.crop_width / im.width 221 | elif is_square(im.width, im.height): 222 | if is_square(settings.crop_width, settings.crop_height): 223 | scale_by = settings.crop_width / im.width 224 
| elif is_landscape(settings.crop_width, settings.crop_height): 225 | scale_by = settings.crop_width / im.width 226 | elif is_portrait(settings.crop_width, settings.crop_height): 227 | scale_by = settings.crop_height / im.height 228 | 229 | im = im.resize((int(im.width * scale_by), int(im.height * scale_by))) 230 | im_debug = im.copy() 231 | 232 | focus = focal_point(im_debug, settings) 233 | 234 | # take the focal point and turn it into crop coordinates that try to center over the focal 235 | # point but then get adjusted back into the frame 236 | y_half = int(settings.crop_height / 2) 237 | x_half = int(settings.crop_width / 2) 238 | 239 | x1 = focus.x - x_half 240 | if x1 < 0: 241 | x1 = 0 242 | elif x1 + settings.crop_width > im.width: 243 | x1 = im.width - settings.crop_width 244 | 245 | y1 = focus.y - y_half 246 | if y1 < 0: 247 | y1 = 0 248 | elif y1 + settings.crop_height > im.height: 249 | y1 = im.height - settings.crop_height 250 | 251 | x2 = x1 + settings.crop_width 252 | y2 = y1 + settings.crop_height 253 | 254 | crop = [x1, y1, x2, y2] 255 | 256 | results = [] 257 | 258 | results.append(im.crop(tuple(crop))) 259 | 260 | return results 261 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © 2022 Ben 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fast-stable-diffusion Notebooks, AUTOMATIC1111 + DreamBooth 2 | Colab, Runpod & Paperspace adaptations of the AUTOMATIC1111 WebUI and DreamBooth. 3 | 4 |
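The Dreambooth scripts dumped above include a checkpoint-to-Diffusers converter. As a minimal illustrative sketch (not part of the repository), its helper functions could be imported and called directly; the module name `convertosdv2`, the checkpoint path, the output directory, and the reference model below are all assumptions for the example, not values taken from the repo.

```python
# Minimal sketch, assuming the converter file shown above is importable as `convertosdv2`
# and that the paths and reference model are placeholders.
from convertosdv2 import (
    load_models_from_stable_diffusion_checkpoint,
    save_diffusers_checkpoint,
)

v2 = True  # True for an SD 2.x checkpoint, False for 1.x
# Rebuilds (text_encoder, vae, unet) from the checkpoint's state_dict
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(v2, "model.ckpt")

# Writes a Diffusers-format folder; scheduler and tokenizer are copied from the reference model
save_diffusers_checkpoint(v2, "converted-model", text_encoder, unet,
                          "stabilityai/stable-diffusion-2-1", vae=vae)
```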