├── .github └── workflows │ ├── autotag.yaml │ ├── black.yml │ ├── publish-to-test-pypi.yml │ └── pull_request_template.md ├── .gitignore ├── LICENSE.txt ├── Makefile ├── README.md ├── merge_models.py ├── pyproject.toml ├── requirements.txt └── sd_meh ├── __init__.py ├── merge.py ├── merge_methods.py ├── model.py ├── presets.py ├── rebasin.py └── utils.py /.github/workflows/autotag.yaml: -------------------------------------------------------------------------------- 1 | name: Python 🐍 Auto Version Tag 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | tag: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | 13 | - name: Version tag 14 | uses: samamorgan/action-autotag-python@master 15 | 16 | with: 17 | path: sd_meh/__init__.py 18 | variable: __version__ 19 | github_token: ${{ secrets.AUTOTAG }} 20 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: black 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | black: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable 11 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: publish sd-meh to pypi and testpypi 2 | on: 3 | push: 4 | tags: 5 | - "*" 6 | 7 | jobs: 8 | build-n-publish: 9 | name: build and publish sd-meh to pypi and testpypi 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: set up python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.x" 17 | - name: install pypa/build 18 | run: >- 19 | python -m 20 | pip install 21 | build 22 | --user 23 | - name: build a binary wheel and a src tarball 24 | run: >- 25 | python -m 26 | build 27 | --sdist 28 | --wheel 29 | --outdir 
dist/ 30 | . 31 | # - name: publish to testpypi 32 | # uses: pypa/gh-action-pypi-publish@release/v1 33 | # with: 34 | # password: ${{ secrets.TEST_PYPI_API_TOKEN }} 35 | # repository-url: https://test.pypi.org/legacy/ 36 | - name: publish to pypi 37 | if: startsWith(github.ref, 'refs/tags') 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | with: 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.github/workflows/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Describe your changes 2 | 3 | ## Issue ticket number and link 4 | 5 | ## Checklist before requesting a review 6 | - [ ] I have performed a self-review of my code 7 | - [ ] The PR title follows the [conventional commits specs] 8 | 9 | ##### In case of `[BREAKING CHANGE]`, `[feat]` or `[fix]` 10 | - [ ] I have updated version in `sd_meh/__init.py__` following [semantic versioning](https://semver.org) rules 11 | - [ ] I have updated version in `pyproject.toml` 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/__pycache__/* 3 | dist/* 4 | **/*.png 5 | *.safetensors 6 | *.ckpt 7 | .idea/ 8 | .mypy_cache/* 9 | **/*.swp 10 | **/*.swo 11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 s1dlx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to 
permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: format 2 | format: 3 | isort . 4 | autoflake --remove-all-unused-imports -i -r --exclude __init__.py . 5 | black . -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sd-meh 2 | 3 | [![PyPI version](https://badge.fury.io/py/sd-meh.svg)](https://badge.fury.io/py/sd-meh) 4 | 5 | [![](https://dcbadge.vercel.app/api/server/EZJuBfNVHh)](https://discord.gg/EZJuBfNVHh) 6 | 7 | 8 | The merging execution helper (meh) is a python module for stable diffusion models merging. 9 | This repository will never contain code for a webui extension. 10 | This is because the aim is to have a GUI agnostic merging engine that can be reused in multiple extensions. 
11 | 12 | You can install the module as 13 | 14 | ``` 15 | pip install sd-meh 16 | ``` 17 | 18 | and then use it in your extension as 19 | 20 | ```python 21 | from sd_meh.merge import merge_models 22 | 23 | merged_model = merge_models(models, weights, bases, merge_mode, precision) 24 | ``` 25 | 26 | You can have a look at the provided `merge_models.py` cli for an example on how to use the function. Run `python3 merge_models.py --help` for a list of the available arguments. 27 | 28 | [Join](https://discord.gg/EZJuBfNVHh) our discord server for discussion and features/bugfix requests 29 | 30 | ## Changelog 31 | 32 | ### 0.9.1 ... 0.9.3 33 | - bugfixes 34 | - support for pix2pix and inpainting models 35 | 36 | ### 0.8.0 37 | - add `-bwpab, --block_weights_preset_alpha_b"` and `-pal, --presets_alpha_lambda` for presets interpolation (same for `beta`) 38 | - add `-ll, --logging_level`, default to `INFO` 39 | 40 | ### 0.7.0 41 | - add `-bwpa, --block_weights_preset_alpha` and `-bwpb, --block_weights_preset_beta` to use pre-defined merging weights. Have a look at the [wiki](https://github.com/s1dlx/meh/wiki/Presets) for all the presets 42 | - add `-wd, --work_device` 43 | - add `-pr, --prune` 44 | - add `-j, --threads` 45 | 46 | 47 | ## DEV 48 | 49 | PRs are welcome for both new features and bug fixes. 50 | 51 | Please make sure you format the code with black (you can `make format`) before submitting a PR. 
52 | 53 | ### You want to add a `feature` 54 | 55 | - open a `feat:` PR merging to `dev` branch, not `main` 56 | - *do not* update version numbers 57 | - ask for a review 58 | 59 | ### You want to make a bug `fix` 60 | 61 | - open a `fix:` PR mergin to `main` 62 | - update version number in `pyproject.toml` and `sd_meh/__init__.py` 63 | - ask for a review 64 | -------------------------------------------------------------------------------- /merge_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import logging 3 | 4 | import click 5 | 6 | from sd_meh.merge import merge_models, save_model 7 | from sd_meh.presets import BLOCK_WEIGHTS_PRESETS 8 | from sd_meh.utils import MERGE_METHODS, weights_and_bases 9 | 10 | 11 | @click.command() 12 | @click.option("-a", "--model_a", "model_a", type=str) 13 | @click.option("-b", "--model_b", "model_b", type=str) 14 | @click.option("-c", "--model_c", "model_c", default=None, type=str) 15 | @click.option( 16 | "-m", 17 | "--merging_method", 18 | "merge_mode", 19 | type=click.Choice(list(MERGE_METHODS.keys()), case_sensitive=False), 20 | ) 21 | @click.option("-wc", "--weights_clip", "weights_clip", is_flag=True) 22 | @click.option("-p", "--precision", "precision", type=int, default=16) 23 | @click.option("-o", "--output_path", "output_path", type=str, default="model_out") 24 | @click.option( 25 | "-f", 26 | "--output_format", 27 | "output_format", 28 | type=click.Choice(["safetensors", "ckpt"], case_sensitive=False), 29 | ) 30 | @click.option("-wa", "--weights_alpha", "weights_alpha", type=str, default=None) 31 | @click.option("-ba", "--base_alpha", "base_alpha", type=float, default=0.0) 32 | @click.option("-wb", "--weights_beta", "weights_beta", type=str, default=None) 33 | @click.option("-bb", "--base_beta", "base_beta", type=float, default=0.0) 34 | @click.option("-rb", "--re_basin", "re_basin", is_flag=True) 35 | @click.option( 36 | "-rbi", 
"--re_basin_iterations", "re_basin_iterations", type=int, default=1 37 | ) 38 | @click.option( 39 | "-d", 40 | "--device", 41 | "device", 42 | type=click.Choice( 43 | ["cpu", "cuda"], 44 | case_sensitive=False, 45 | ), 46 | default="cpu", 47 | ) 48 | @click.option( 49 | "-wd", 50 | "--work_device", 51 | "work_device", 52 | type=click.Choice( 53 | ["cpu", "cuda"], 54 | case_sensitive=False, 55 | ), 56 | default=None, 57 | ) 58 | @click.option("-pr", "--prune", "prune", is_flag=True) 59 | @click.option( 60 | "-bwpa", 61 | "--block_weights_preset_alpha", 62 | "block_weights_preset_alpha", 63 | type=click.Choice(list(BLOCK_WEIGHTS_PRESETS.keys()), case_sensitive=False), 64 | default=None, 65 | ) 66 | @click.option( 67 | "-bwpb", 68 | "--block_weights_preset_beta", 69 | "block_weights_preset_beta", 70 | type=click.Choice(list(BLOCK_WEIGHTS_PRESETS.keys()), case_sensitive=False), 71 | default=None, 72 | ) 73 | @click.option( 74 | "-j", 75 | "--threads", 76 | "threads", 77 | type=int, 78 | default=1, 79 | ) 80 | @click.option( 81 | "-bwpab", 82 | "--block_weights_preset_alpha_b", 83 | "block_weights_preset_alpha_b", 84 | type=click.Choice(list(BLOCK_WEIGHTS_PRESETS.keys()), case_sensitive=False), 85 | default=None, 86 | ) 87 | @click.option( 88 | "-bwpbb", 89 | "--block_weights_preset_beta_b", 90 | "block_weights_preset_beta_b", 91 | type=click.Choice(list(BLOCK_WEIGHTS_PRESETS.keys()), case_sensitive=False), 92 | default=None, 93 | ) 94 | @click.option( 95 | "-pal", 96 | "--presets_alpha_lambda", 97 | "presets_alpha_lambda", 98 | type=float, 99 | default=None, 100 | ) 101 | @click.option( 102 | "-pbl", 103 | "--presets_beta_lambda", 104 | "presets_beta_lambda", 105 | type=float, 106 | default=None, 107 | ) 108 | @click.option( 109 | "-ll", 110 | "--logging_level", 111 | "logging_level", 112 | type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), 113 | default="INFO", 114 | ) 115 | def main( 116 | model_a, 117 | model_b, 118 | model_c, 119 | 
merge_mode, 120 | weights_clip, 121 | precision, 122 | output_path, 123 | output_format, 124 | weights_alpha, 125 | base_alpha, 126 | weights_beta, 127 | base_beta, 128 | re_basin, 129 | re_basin_iterations, 130 | device, 131 | work_device, 132 | prune, 133 | block_weights_preset_alpha, 134 | block_weights_preset_beta, 135 | threads, 136 | block_weights_preset_alpha_b, 137 | block_weights_preset_beta_b, 138 | presets_alpha_lambda, 139 | presets_beta_lambda, 140 | logging_level, 141 | ): 142 | if logging_level: 143 | logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level) 144 | 145 | models = {"model_a": model_a, "model_b": model_b} 146 | if model_c: 147 | models["model_c"] = model_c 148 | 149 | weights, bases = weights_and_bases( 150 | merge_mode, 151 | weights_alpha, 152 | base_alpha, 153 | block_weights_preset_alpha, 154 | weights_beta, 155 | base_beta, 156 | block_weights_preset_beta, 157 | block_weights_preset_alpha_b, 158 | block_weights_preset_beta_b, 159 | presets_alpha_lambda, 160 | presets_beta_lambda, 161 | ) 162 | 163 | merged = merge_models( 164 | models, 165 | weights, 166 | bases, 167 | merge_mode, 168 | precision, 169 | weights_clip, 170 | re_basin, 171 | re_basin_iterations, 172 | device, 173 | work_device, 174 | prune, 175 | threads, 176 | ) 177 | 178 | save_model(merged, output_path, output_format) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sd-meh" 3 | version = "0.9.5" 4 | description = "stable diffusion merging execution helper" 5 | authors = ["s1dlx "] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{include = "sd_meh"}] 9 | repository = "https://github.com/s1dlx/meh" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.10" 13 | 14 | 15 | [build-system] 16 | requires = 
["poetry-core"] 17 | build-backend = "poetry.core.masonry.api" 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | numpy 3 | safetensors 4 | torch 5 | tqdm 6 | tensordict 7 | scipy 8 | -------------------------------------------------------------------------------- /sd_meh/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.9.5" 2 | -------------------------------------------------------------------------------- /sd_meh/merge.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import os 4 | import re 5 | from concurrent.futures import ThreadPoolExecutor 6 | from contextlib import contextmanager 7 | from pathlib import Path 8 | from typing import Dict, Optional, Tuple 9 | 10 | import safetensors.torch 11 | import torch 12 | from tqdm import tqdm 13 | 14 | from sd_meh import merge_methods 15 | from sd_meh.model import SDModel 16 | from sd_meh.rebasin import ( 17 | apply_permutation, 18 | sdunet_permutation_spec, 19 | step_weights_and_bases, 20 | update_model_a, 21 | weight_matching, 22 | ) 23 | 24 | logging.getLogger("sd_meh").addHandler(logging.NullHandler()) 25 | MAX_TOKENS = 77 26 | NUM_INPUT_BLOCKS = 12 27 | NUM_MID_BLOCK = 1 28 | NUM_OUTPUT_BLOCKS = 12 29 | NUM_TOTAL_BLOCKS = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + NUM_OUTPUT_BLOCKS 30 | 31 | KEY_POSITION_IDS = ".".join( 32 | [ 33 | "cond_stage_model", 34 | "transformer", 35 | "text_model", 36 | "embeddings", 37 | "position_ids", 38 | ] 39 | ) 40 | 41 | 42 | NAI_KEYS = { 43 | "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.", 44 | "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.", 45 | "cond_stage_model.transformer.final_layer_norm.": 
"cond_stage_model.transformer.text_model.final_layer_norm.", 46 | } 47 | 48 | 49 | def fix_clip(model: Dict) -> Dict: 50 | if KEY_POSITION_IDS in model.keys(): 51 | model[KEY_POSITION_IDS] = torch.tensor( 52 | [list(range(MAX_TOKENS))], 53 | dtype=torch.int64, 54 | device=model[KEY_POSITION_IDS].device, 55 | ) 56 | 57 | return model 58 | 59 | 60 | def fix_key(model: Dict, key: str) -> Dict: 61 | for nk in NAI_KEYS: 62 | if key.startswith(nk): 63 | model[key.replace(nk, NAI_KEYS[nk])] = model[key] 64 | del model[key] 65 | 66 | return model 67 | 68 | 69 | # https://github.com/j4ded/sdweb-merge-block-weighted-gui/blob/master/scripts/mbw/merge_block_weighted.py#L115 70 | def fix_model(model: Dict) -> Dict: 71 | for k in model.keys(): 72 | model = fix_key(model, k) 73 | return fix_clip(model) 74 | 75 | 76 | def load_sd_model(model: os.PathLike | str, device: str = "cpu") -> Dict: 77 | if isinstance(model, str): 78 | model = Path(model) 79 | 80 | return SDModel(model, device).load_model() 81 | 82 | 83 | def prune_sd_model(model: Dict) -> Dict: 84 | keys = list(model.keys()) 85 | for k in keys: 86 | if ( 87 | not k.startswith("model.diffusion_model.") 88 | and not k.startswith("first_stage_model.") 89 | and not k.startswith("cond_stage_model.") 90 | ): 91 | del model[k] 92 | return model 93 | 94 | 95 | def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict: 96 | for k in original_model: 97 | if k not in merged_model: 98 | merged_model[k] = original_model[k] 99 | return merged_model 100 | 101 | 102 | def log_vram(txt=""): 103 | alloc = torch.cuda.memory_allocated(0) 104 | logging.debug(f"{txt} VRAM: {alloc*1e-9:5.3f}GB") 105 | 106 | 107 | def load_thetas( 108 | models: Dict[str, os.PathLike | str], 109 | prune: bool, 110 | device: str, 111 | precision: int, 112 | ) -> Dict: 113 | log_vram("before loading models") 114 | if prune: 115 | thetas = {k: prune_sd_model(load_sd_model(m, "cpu")) for k, m in models.items()} 116 | else: 117 | thetas = {k: 
load_sd_model(m, device) for k, m in models.items()} 118 | 119 | if device == "cuda": 120 | for model_key, model in thetas.items(): 121 | for key, block in model.items(): 122 | if precision == 16: 123 | thetas[model_key].update({key: block.to(device).half()}) 124 | else: 125 | thetas[model_key].update({key: block.to(device)}) 126 | 127 | log_vram("models loaded") 128 | return thetas 129 | 130 | 131 | def merge_models( 132 | models: Dict[str, os.PathLike | str], 133 | weights: Dict, 134 | bases: Dict, 135 | merge_mode: str, 136 | precision: int = 16, 137 | weights_clip: bool = False, 138 | re_basin: bool = False, 139 | iterations: int = 1, 140 | device: str = "cpu", 141 | work_device: Optional[str] = None, 142 | prune: bool = False, 143 | threads: int = 1, 144 | ) -> Dict: 145 | thetas = load_thetas(models, prune, device, precision) 146 | 147 | logging.info(f"start merging with {merge_mode} method") 148 | if re_basin: 149 | merged = rebasin_merge( 150 | thetas, 151 | weights, 152 | bases, 153 | merge_mode, 154 | precision=precision, 155 | weights_clip=weights_clip, 156 | iterations=iterations, 157 | device=device, 158 | work_device=work_device, 159 | threads=threads, 160 | ) 161 | else: 162 | merged = simple_merge( 163 | thetas, 164 | weights, 165 | bases, 166 | merge_mode, 167 | precision=precision, 168 | weights_clip=weights_clip, 169 | device=device, 170 | work_device=work_device, 171 | threads=threads, 172 | ) 173 | 174 | return un_prune_model(merged, thetas, models, device, prune, precision) 175 | 176 | 177 | def un_prune_model( 178 | merged: Dict, 179 | thetas: Dict, 180 | models: Dict, 181 | device: str, 182 | prune: bool, 183 | precision: int, 184 | ) -> Dict: 185 | if prune: 186 | logging.info("Un-pruning merged model") 187 | del thetas 188 | gc.collect() 189 | log_vram("remove thetas") 190 | original_a = load_sd_model(models["model_a"], device) 191 | for key in tqdm(original_a.keys(), desc="un-prune model a"): 192 | if KEY_POSITION_IDS in key: 193 | 
continue 194 | if "model" in key and key not in merged.keys(): 195 | merged.update({key: original_a[key]}) 196 | if precision == 16: 197 | merged.update({key: merged[key].half()}) 198 | del original_a 199 | gc.collect() 200 | log_vram("remove original_a") 201 | original_b = load_sd_model(models["model_b"], device) 202 | for key in tqdm(original_b.keys(), desc="un-prune model b"): 203 | if KEY_POSITION_IDS in key: 204 | continue 205 | if "model" in key and key not in merged.keys(): 206 | merged.update({key: original_b[key]}) 207 | if precision == 16: 208 | merged.update({key: merged[key].half()}) 209 | del original_b 210 | 211 | return fix_model(merged) 212 | 213 | 214 | def simple_merge( 215 | thetas: Dict[str, Dict], 216 | weights: Dict, 217 | bases: Dict, 218 | merge_mode: str, 219 | precision: int = 16, 220 | weights_clip: bool = False, 221 | device: str = "cpu", 222 | work_device: Optional[str] = None, 223 | threads: int = 1, 224 | ) -> Dict: 225 | futures = [] 226 | with tqdm(thetas["model_a"].keys(), desc="stage 1") as progress: 227 | with ThreadPoolExecutor(max_workers=threads) as executor: 228 | for key in thetas["model_a"].keys(): 229 | future = executor.submit( 230 | simple_merge_key, 231 | progress, 232 | key, 233 | thetas, 234 | weights, 235 | bases, 236 | merge_mode, 237 | precision, 238 | weights_clip, 239 | device, 240 | work_device, 241 | ) 242 | futures.append(future) 243 | 244 | for res in futures: 245 | res.result() 246 | 247 | log_vram("after stage 1") 248 | 249 | for key in tqdm(thetas["model_b"].keys(), desc="stage 2"): 250 | if KEY_POSITION_IDS in key: 251 | continue 252 | if "model" in key and key not in thetas["model_a"].keys(): 253 | thetas["model_a"].update({key: thetas["model_b"][key]}) 254 | if precision == 16: 255 | thetas["model_a"].update({key: thetas["model_a"][key].half()}) 256 | 257 | log_vram("after stage 2") 258 | 259 | return fix_model(thetas["model_a"]) 260 | 261 | 262 | def rebasin_merge( 263 | thetas: Dict[str, os.PathLike | 
str], 264 | weights: Dict, 265 | bases: Dict, 266 | merge_mode: str, 267 | precision: int = 16, 268 | weights_clip: bool = False, 269 | iterations: int = 1, 270 | device="cpu", 271 | work_device=None, 272 | threads: int = 1, 273 | ): 274 | # WARNING: not sure how this does when 3 models are involved... 275 | 276 | model_a = thetas["model_a"].clone() 277 | perm_spec = sdunet_permutation_spec() 278 | 279 | logging.info("Init rebasin iterations") 280 | for it in range(iterations): 281 | logging.info(f"Rebasin iteration {it}") 282 | log_vram(f"{it} iteration start") 283 | new_weights, new_bases = step_weights_and_bases( 284 | weights, 285 | bases, 286 | it, 287 | iterations, 288 | ) 289 | log_vram("weights & bases, before simple merge") 290 | 291 | # normal block merge we already know and love 292 | thetas["model_a"] = simple_merge( 293 | thetas, 294 | new_weights, 295 | new_bases, 296 | merge_mode, 297 | precision, 298 | False, 299 | device, 300 | work_device, 301 | threads, 302 | ) 303 | 304 | log_vram("simple merge done") 305 | 306 | # find permutations 307 | perm_1, y = weight_matching( 308 | perm_spec, 309 | model_a, 310 | thetas["model_a"], 311 | max_iter=it, 312 | init_perm=None, 313 | usefp16=precision == 16, 314 | device=device, 315 | ) 316 | 317 | log_vram("weight matching #1 done") 318 | 319 | thetas["model_a"] = apply_permutation(perm_spec, perm_1, thetas["model_a"]) 320 | 321 | log_vram("apply perm 1 done") 322 | 323 | perm_2, z = weight_matching( 324 | perm_spec, 325 | thetas["model_b"], 326 | thetas["model_a"], 327 | max_iter=it, 328 | init_perm=None, 329 | usefp16=precision == 16, 330 | device=device, 331 | ) 332 | 333 | log_vram("weight matching #2 done") 334 | 335 | new_alpha = torch.nn.functional.normalize( 336 | torch.sigmoid(torch.Tensor([y, z])), p=1, dim=0 337 | ).tolist()[0] 338 | thetas["model_a"] = update_model_a( 339 | perm_spec, perm_2, thetas["model_a"], new_alpha 340 | ) 341 | 342 | log_vram("model a updated") 343 | 344 | if weights_clip: 
345 | clip_thetas = thetas.copy() 346 | clip_thetas["model_a"] = model_a 347 | thetas["model_a"] = clip_weights(thetas, thetas["model_a"]) 348 | 349 | return thetas["model_a"] 350 | 351 | 352 | def simple_merge_key(progress, key, thetas, *args, **kwargs): 353 | with merge_key_context(key, thetas, *args, **kwargs) as result: 354 | if result is not None: 355 | thetas["model_a"].update({key: result.detach().clone()}) 356 | 357 | progress.update() 358 | 359 | 360 | def merge_key( 361 | key: str, 362 | thetas: Dict, 363 | weights: Dict, 364 | bases: Dict, 365 | merge_mode: str, 366 | precision: int = 16, 367 | weights_clip: bool = False, 368 | device: str = "cpu", 369 | work_device: Optional[str] = None, 370 | ) -> Optional[Tuple[str, Dict]]: 371 | if work_device is None: 372 | work_device = device 373 | 374 | if KEY_POSITION_IDS in key: 375 | return 376 | 377 | for theta in thetas.values(): 378 | if key not in theta.keys(): 379 | return 380 | 381 | if "model" in key: 382 | current_bases = bases 383 | 384 | if "model.diffusion_model." in key: 385 | weight_index = -1 386 | 387 | re_inp = re.compile(r"\.input_blocks\.(\d+)\.") # 12 388 | re_mid = re.compile(r"\.middle_block\.(\d+)\.") # 1 389 | re_out = re.compile(r"\.output_blocks\.(\d+)\.") # 12 390 | 391 | if "time_embed" in key: 392 | weight_index = 0 # before input blocks 393 | elif ".out." 
in key: 394 | weight_index = NUM_TOTAL_BLOCKS - 1 # after output blocks 395 | elif m := re_inp.search(key): 396 | weight_index = int(m.groups()[0]) 397 | elif re_mid.search(key): 398 | weight_index = NUM_INPUT_BLOCKS 399 | elif m := re_out.search(key): 400 | weight_index = NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(m.groups()[0]) 401 | 402 | if weight_index >= NUM_TOTAL_BLOCKS: 403 | raise ValueError(f"illegal block index {key}") 404 | 405 | if weight_index >= 0: 406 | current_bases = {k: w[weight_index] for k, w in weights.items()} 407 | 408 | try: 409 | merge_method = getattr(merge_methods, merge_mode) 410 | except AttributeError as e: 411 | raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e 412 | 413 | merge_args = get_merge_method_args(current_bases, thetas, key, work_device) 414 | 415 | # dealing wiht pix2pix and inpainting models 416 | if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()): 417 | if a_size[1] > b_size[1]: 418 | merged_key = merge_args["a"] 419 | else: 420 | merged_key = merge_args["b"] 421 | else: 422 | merged_key = merge_method(**merge_args).to(device) 423 | 424 | if weights_clip: 425 | merged_key = clip_weights_key(thetas, merged_key, key) 426 | 427 | if precision == 16: 428 | merged_key = merged_key.half() 429 | 430 | return merged_key 431 | 432 | 433 | def clip_weights(thetas, merged): 434 | for k in thetas["model_a"].keys(): 435 | if k in thetas["model_b"].keys(): 436 | merged.update({k: clip_weights_key(thetas, merged[k], k)}) 437 | return merged 438 | 439 | 440 | def clip_weights_key(thetas, merged_weights, key): 441 | t0 = thetas["model_a"][key] 442 | t1 = thetas["model_b"][key] 443 | maximums = torch.maximum(t0, t1) 444 | minimums = torch.minimum(t0, t1) 445 | return torch.minimum(torch.maximum(merged_weights, minimums), maximums) 446 | 447 | 448 | @contextmanager 449 | def merge_key_context(*args, **kwargs): 450 | result = merge_key(*args, **kwargs) 451 | try: 452 | yield result 453 | finally: 
454 | if result is not None: 455 | del result 456 | 457 | 458 | def get_merge_method_args( 459 | current_bases: Dict, 460 | thetas: Dict, 461 | key: str, 462 | work_device: str, 463 | ) -> Dict: 464 | merge_method_args = { 465 | "a": thetas["model_a"][key].to(work_device), 466 | "b": thetas["model_b"][key].to(work_device), 467 | **current_bases, 468 | } 469 | 470 | if "model_c" in thetas: 471 | merge_method_args["c"] = thetas["model_c"][key].to(work_device) 472 | 473 | return merge_method_args 474 | 475 | 476 | def save_model(model, output_file, file_format) -> None: 477 | logging.info(f"Saving {output_file}") 478 | if file_format == "safetensors": 479 | safetensors.torch.save_file( 480 | model if type(model) == dict else model.to_dict(), 481 | f"{output_file}.safetensors", 482 | metadata={"format": "pt"}, 483 | ) 484 | else: 485 | torch.save({"state_dict": model}, f"{output_file}.ckpt") 486 | -------------------------------------------------------------------------------- /sd_meh/merge_methods.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | __all__ = [ 8 | "weighted_sum", 9 | "weighted_subtraction", 10 | "tensor_sum", 11 | "add_difference", 12 | "sum_twice", 13 | "triple_sum", 14 | "euclidean_add_difference", 15 | "multiply_difference", 16 | "top_k_tensor_sum", 17 | "similarity_add_difference", 18 | "distribution_crossover", 19 | "ties_add_difference", 20 | ] 21 | 22 | 23 | EPSILON = 1e-10 # Define a small constant EPSILON to prevent division by zero 24 | 25 | 26 | def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor: 27 | return (1 - alpha) * a + alpha * b 28 | 29 | 30 | def weighted_subtraction( 31 | a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs 32 | ) -> Tensor: 33 | # Adjust beta if both alpha and beta are 1.0 to avoid division by zero 34 | if alpha == 1.0 and beta == 1.0: 35 | beta -= 
EPSILON 36 | 37 | return (a - alpha * beta * b) / (1 - alpha * beta) 38 | 39 | 40 | def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: 41 | if alpha + beta <= 1: 42 | tt = a.clone() 43 | talphas = int(a.shape[0] * beta) 44 | talphae = int(a.shape[0] * (alpha + beta)) 45 | tt[talphas:talphae] = b[talphas:talphae].clone() 46 | else: 47 | talphas = int(a.shape[0] * (alpha + beta - 1)) 48 | talphae = int(a.shape[0] * beta) 49 | tt = b.clone() 50 | tt[talphas:talphae] = a[talphas:talphae].clone() 51 | return tt 52 | 53 | 54 | def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor: 55 | return a + alpha * (b - c) 56 | 57 | 58 | def sum_twice( 59 | a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs 60 | ) -> Tensor: 61 | return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c 62 | 63 | 64 | def triple_sum( 65 | a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs 66 | ) -> Tensor: 67 | return (1 - alpha - beta) * a + alpha * b + beta * c 68 | 69 | 70 | def euclidean_add_difference( 71 | a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs 72 | ) -> Tensor: 73 | a_diff = a.float() - c.float() 74 | b_diff = b.float() - c.float() 75 | a_diff = torch.nan_to_num(a_diff / torch.linalg.norm(a_diff)) 76 | b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff)) 77 | 78 | distance = (1 - alpha) * a_diff**2 + alpha * b_diff**2 79 | distance = torch.sqrt(distance) 80 | sum_diff = weighted_sum(a.float(), b.float(), alpha) - c.float() 81 | distance = torch.copysign(distance, sum_diff) 82 | 83 | target_norm = torch.linalg.norm(sum_diff) 84 | return c + distance / torch.linalg.norm(distance) * target_norm 85 | 86 | 87 | def multiply_difference( 88 | a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs 89 | ) -> Tensor: 90 | diff_a = torch.pow(torch.abs(a.float() - c), (1 - alpha)) 91 | diff_b = torch.pow(torch.abs(b.float() - c), alpha) 92 | difference = 
torch.copysign(diff_a * diff_b, weighted_sum(a, b, beta) - c) 93 | return c + difference.to(c.dtype) 94 | 95 | 96 | def top_k_tensor_sum( 97 | a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs 98 | ) -> Tensor: 99 | a_flat = torch.flatten(a) 100 | a_dist = torch.msort(a_flat) 101 | b_indices = torch.argsort(torch.flatten(b), stable=True) 102 | redist_indices = torch.argsort(b_indices) 103 | 104 | start_i, end_i, region_is_inverted = ratio_to_region(alpha, beta, torch.numel(a)) 105 | start_top_k = kth_abs_value(a_dist, start_i) 106 | end_top_k = kth_abs_value(a_dist, end_i) 107 | 108 | indices_mask = (start_top_k < torch.abs(a_dist)) & (torch.abs(a_dist) <= end_top_k) 109 | if region_is_inverted: 110 | indices_mask = ~indices_mask 111 | indices_mask = torch.gather(indices_mask.float(), 0, redist_indices) 112 | 113 | a_redist = torch.gather(a_dist, 0, redist_indices) 114 | a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist 115 | return a_redist.reshape_as(a) 116 | 117 | 118 | def kth_abs_value(a: Tensor, k: int) -> Tensor: 119 | if k <= 0: 120 | return torch.tensor(-1, device=a.device) 121 | else: 122 | return torch.kthvalue(torch.abs(a.float()), k)[0] 123 | 124 | 125 | def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]: 126 | if width < 0: 127 | offset += width 128 | width = -width 129 | width = min(width, 1) 130 | 131 | if offset < 0: 132 | offset = 1 + offset - int(offset) 133 | offset = math.fmod(offset, 1.0) 134 | 135 | if width + offset <= 1: 136 | inverted = False 137 | start = offset * n 138 | end = (width + offset) * n 139 | else: 140 | inverted = True 141 | start = (width + offset - 1) * n 142 | end = offset * n 143 | 144 | return round(start), round(end), inverted 145 | 146 | 147 | def similarity_add_difference( 148 | a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs 149 | ) -> Tensor: 150 | threshold = torch.maximum(torch.abs(a), torch.abs(b)) 151 | similarity = ((a * b / 
def distribution_crossover(
    a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs
):
    """Crossover of ``a`` and ``b`` in the frequency domain of their
    ``c``-ordered value distributions; ``alpha`` places the filter cutoff
    and ``beta`` controls how soft the transition is."""
    if a.shape == ():
        # NOTE(review): this scalar fallback weights b by (1 - alpha), the
        # reverse of the usual weighted-sum convention — confirm intentional.
        return alpha * a + (1 - alpha) * b

    order_by_c = torch.argsort(torch.flatten(c))
    dist_a = torch.gather(torch.flatten(a), 0, order_by_c)
    dist_b = torch.gather(torch.flatten(b), 0, order_by_c)

    spectrum_a = torch.fft.rfft(dist_a.float())
    spectrum_b = torch.fft.rfft(dist_b.float())

    crossover_filter = torch.arange(
        0, torch.numel(spectrum_a), device=spectrum_a.device
    ).float()
    crossover_filter /= torch.numel(spectrum_a)
    if beta > EPSILON:
        # soft ramp centered on alpha, slope set by beta
        crossover_filter = (crossover_filter - alpha) / beta + 1 / 2
        crossover_filter = torch.clamp(crossover_filter, 0.0, 1.0)
    else:
        # beta ~ 0: hard cutoff at alpha
        crossover_filter = (crossover_filter >= alpha).float()

    blended_spectrum = (1 - crossover_filter) * spectrum_a + crossover_filter * spectrum_b
    blended_dist = torch.fft.irfft(blended_spectrum, dist_a.shape[0])
    # undo the c-ordering to put values back in their original positions
    values = torch.gather(blended_dist, 0, torch.argsort(order_by_c))
    return values.reshape_as(a)


def ties_add_difference(
    a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs
) -> Tensor:
    """TIES-style merge: keep the top-magnitude deltas from ``c``, drop those
    whose sign disagrees with the elementwise majority, average what remains
    and add it back onto ``c`` scaled by ``alpha``."""
    deltas = [filter_top_k(m - c, beta) for m in (a, b)]
    signs = torch.stack([torch.sign(d) for d in deltas], dim=0)
    majority_sign = torch.sign(torch.sum(signs, dim=0))
    agreement = (signs == majority_sign).float()

    merged = torch.zeros_like(c, device=c.device)
    for keep, delta in zip(agreement, deltas):
        merged += keep * delta

    contributors = torch.sum(agreement, dim=0)
    # nan_to_num covers elements where no delta agreed (0 / 0)
    return c + alpha * torch.nan_to_num(merged / contributors)


def filter_top_k(a: Tensor, k: float):
    """Zero out all but the largest-magnitude entries of ``a``; ``k`` is the
    fraction filtered away (at least one element is always kept)."""
    cutoff_rank = max(int((1 - k) * torch.numel(a)), 1)
    cutoff, _ = torch.kthvalue(torch.abs(a.flatten()).float(), cutoff_rank)
    return a * (torch.abs(a) >= cutoff).float()
import logging
import os
from dataclasses import dataclass
from pathlib import Path

import torch

logging.getLogger("sd_meh").addHandler(logging.NullHandler())


@dataclass
class SDModel:
    # Filesystem location of the checkpoint (.safetensors or torch-pickled).
    # Accepts anything Path() accepts, including plain strings.
    model_path: os.PathLike
    # Device spec forwarded to the loaders, e.g. "cpu" or "cuda".
    device: str

    def load_model(self):
        """Load the checkpoint and return its normalized state dict as a
        TensorDict.

        Fix: the original called ``self.model_path.suffix`` directly, which
        crashes with AttributeError when a plain ``str`` is passed despite the
        ``os.PathLike`` annotation — normalize through ``Path`` first.
        """
        model_path = Path(self.model_path)
        logging.info(f"Loading: {model_path}")
        if model_path.suffix == ".safetensors":
            # Imported lazily: optional heavy dependency, only needed for
            # .safetensors checkpoints.
            import safetensors.torch

            ckpt = safetensors.torch.load_file(
                model_path,
                device=self.device,
            )
        else:
            # NOTE(review): torch.load unpickles arbitrary code — only load
            # trusted checkpoints.
            ckpt = torch.load(model_path, map_location=self.device)

        # Lazy import for the same reason as safetensors above.
        from tensordict import TensorDict

        return TensorDict.from_dict(get_state_dict_from_checkpoint(ckpt))


# TODO: tidy up
# from: stable-diffusion-webui/modules/sd_models.py
def get_state_dict_from_checkpoint(pl_sd):
    """Extract the state dict from a loaded checkpoint and rewrite legacy
    keys in place.

    Unwraps an outer ``{"state_dict": ...}`` layer if present, drops a nested
    ``"state_dict"`` entry, renames keys via ``transform_checkpoint_dict_key``
    and returns the (mutated) dict.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)
    sd = {}
    for k, v in pl_sd.items():
        # transform_checkpoint_dict_key always returns a truthy key for
        # non-empty input; the walrus keeps the original guard semantics
        if new_key := transform_checkpoint_dict_key(k):
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)
    return pl_sd


# NOTE(review): "chckpoint" typo kept — this is a public module-level name and
# renaming it could break external importers.
chckpoint_dict_replacements = {
    "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
    "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
    "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
}


def transform_checkpoint_dict_key(k):
    """Rewrite legacy CLIP text-encoder key prefixes to the ``text_model.*``
    layout; keys with no matching prefix pass through unchanged."""
    for text, replacement in chckpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text) :]
    return k
chckpoint_dict_replacements.items(): 54 | if k.startswith(text): 55 | k = replacement + k[len(text) :] 56 | return k 57 | -------------------------------------------------------------------------------- /sd_meh/presets.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | BLOCK_WEIGHTS_PRESETS = { 3 | "GRAD_V": [0, 1, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0], 4 | "GRAD_A": [0, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0], 5 | "FLAT_25": [0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25], 6 | "FLAT_75": [0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 7 | "WRAP08": [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 8 | "WRAP12": [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 9 | "WRAP14": [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 10 | "WRAP16": [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 11 | "MID12_50": [0, 0, 0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0], 12 | "OUT07": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 13 | "OUT12": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 14 | "OUT12_5": [0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 15 | "RING08_SOFT": [0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0], 16 | "RING08_5": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 17 | "RING10_5": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], 18 | "RING10_3": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 19 | "SMOOTHSTEP": [0, 0, 0.00506365740740741, 0.0196759259259259, 0.04296875, 0.0740740740740741, 0.112123842592593, 0.15625, 0.205584490740741, 0.259259259259259, 0.31640625, 0.376157407407407, 0.437644675925926, 0.5, 0.562355324074074, 0.623842592592592, 0.68359375, 0.740740740740741, 0.794415509259259, 0.84375, 0.887876157407408, 0.925925925925926, 0.95703125, 0.980324074074074, 0.994936342592593, 1], 20 | "REVERSE_SMOOTHSTEP": [0, 1, 0.994936342592593, 0.980324074074074, 0.95703125, 0.925925925925926, 0.887876157407407, 0.84375, 0.794415509259259, 0.740740740740741, 0.68359375, 0.623842592592593, 0.562355324074074, 0.5, 0.437644675925926, 0.376157407407408, 0.31640625, 0.259259259259259, 0.205584490740741, 0.15625, 0.112123842592592, 0.0740740740740742, 0.0429687499999996, 0.0196759259259258, 0.00506365740740744, 0], 21 | "2SMOOTHSTEP": [0, 0, 0.0101273148148148, 0.0393518518518519, 0.0859375, 0.148148148148148, 0.224247685185185, 0.3125, 0.411168981481482, 0.518518518518519, 0.6328125, 0.752314814814815, 0.875289351851852, 1.0, 0.875289351851852, 0.752314814814815, 0.6328125, 0.518518518518519, 0.411168981481481, 0.3125, 0.224247685185184, 0.148148148148148, 0.0859375, 0.0393518518518512, 0.0101273148148153, 0], 22 | "2R_SMOOTHSTEP": [0, 1, 0.989872685185185, 0.960648148148148, 0.9140625, 0.851851851851852, 0.775752314814815, 0.6875, 0.588831018518519, 0.481481481481481, 0.3671875, 0.247685185185185, 0.124710648148148, 0.0, 0.124710648148148, 0.247685185185185, 0.3671875, 0.481481481481481, 0.588831018518519, 
0.6875, 0.775752314814816, 0.851851851851852, 0.9140625, 0.960648148148149, 0.989872685185185, 1], 23 | "3SMOOTHSTEP": [0, 0, 0.0151909722222222, 0.0590277777777778, 0.12890625, 0.222222222222222, 0.336371527777778, 0.46875, 0.616753472222222, 0.777777777777778, 0.94921875, 0.871527777777778, 0.687065972222222, 0.5, 0.312934027777778, 0.128472222222222, 0.0507812500000004, 0.222222222222222, 0.383246527777778, 0.53125, 0.663628472222223, 0.777777777777778, 0.87109375, 0.940972222222222, 0.984809027777777, 1], 24 | "3R_SMOOTHSTEP": [0, 1, 0.984809027777778, 0.940972222222222, 0.87109375, 0.777777777777778, 0.663628472222222, 0.53125, 0.383246527777778, 0.222222222222222, 0.05078125, 0.128472222222222, 0.312934027777778, 0.5, 0.687065972222222, 0.871527777777778, 0.94921875, 0.777777777777778, 0.616753472222222, 0.46875, 0.336371527777777, 0.222222222222222, 0.12890625, 0.0590277777777777, 0.0151909722222232, 0], 25 | "4SMOOTHSTEP": [0, 0, 0.0202546296296296, 0.0787037037037037, 0.171875, 0.296296296296296, 0.44849537037037, 0.625, 0.822337962962963, 0.962962962962963, 0.734375, 0.49537037037037, 0.249421296296296, 0.0, 0.249421296296296, 0.495370370370371, 0.734375000000001, 0.962962962962963, 0.822337962962962, 0.625, 0.448495370370369, 0.296296296296297, 0.171875, 0.0787037037037024, 0.0202546296296307, 0], 26 | "4R_SMOOTHSTEP": [0, 1, 0.97974537037037, 0.921296296296296, 0.828125, 0.703703703703704, 0.55150462962963, 0.375, 0.177662037037037, 0.0370370370370372, 0.265625, 0.50462962962963, 0.750578703703704, 1.0, 0.750578703703704, 0.504629629629629, 0.265624999999999, 0.0370370370370372, 0.177662037037038, 0.375, 0.551504629629631, 0.703703703703703, 0.828125, 0.921296296296298, 0.979745370370369, 1], 27 | "HALF_SMOOTHSTEP": [0, 0, 0.0196759259259259, 0.0740740740740741, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1.0, 0.980324074074074, 0.925925925925926, 0.84375, 
0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740741, 0.0196759259259259, 0], 28 | "HALF_R_SMOOTHSTEP": [0, 1, 0.980324074074074, 0.925925925925926, 0.84375, 0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740742, 0.0196759259259256, 0.0, 0.0196759259259256, 0.0740740740740742, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1], 29 | "ONE_THIRD_SMOOTHSTEP": [0, 0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1], 30 | "ONE_THIRD_R_SMOOTHSTEP": [0, 1, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0], 31 | "ONE_FOURTH_SMOOTHSTEP": [0, 0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740741, 0.0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740741, 0], 32 | "ONE_FOURTH_R_SMOOTHSTEP": [0, 1, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1], 33 | "COSINE": [0, 1, 0.995722430686905, 0.982962913144534, 0.961939766255643, 0.933012701892219, 0.896676670145617, 0.853553390593274, 0.80438071450436, 0.75, 
0.691341716182545, 0.62940952255126, 0.565263096110026, 0.5, 0.434736903889974, 0.37059047744874, 0.308658283817455, 0.25, 0.195619285495639, 0.146446609406726, 0.103323329854382, 0.0669872981077805, 0.0380602337443566, 0.0170370868554658, 0.00427756931309475, 0], 34 | "REVERSE_COSINE": [0, 0, 0.00427756931309475, 0.0170370868554659, 0.0380602337443566, 0.0669872981077808, 0.103323329854383, 0.146446609406726, 0.19561928549564, 0.25, 0.308658283817455, 0.37059047744874, 0.434736903889974, 0.5, 0.565263096110026, 0.62940952255126, 0.691341716182545, 0.75, 0.804380714504361, 0.853553390593274, 0.896676670145618, 0.933012701892219, 0.961939766255643, 0.982962913144534, 0.995722430686905, 1], 35 | "TRUE_CUBIC_HERMITE": [0, 0, 0.199031876929012, 0.325761959876543, 0.424641927083333, 0.498456790123457, 0.549991560570988, 0.58203125, 0.597360869984568, 0.598765432098765, 0.589029947916667, 0.570939429012346, 0.547278886959876, 0.520833333333333, 0.49438777970679, 0.470727237654321, 0.45263671875, 0.442901234567901, 0.444305796682099, 0.459635416666667, 0.491675106095678, 0.543209876543211, 0.617024739583333, 0.715904706790124, 0.842634789737655, 1], 36 | "TRUE_REVERSE_CUBIC_HERMITE": [0, 1, 0.800968123070988, 0.674238040123457, 0.575358072916667, 0.501543209876543, 0.450008439429012, 0.41796875, 0.402639130015432, 0.401234567901235, 0.410970052083333, 0.429060570987654, 0.452721113040124, 0.479166666666667, 0.50561222029321, 0.529272762345679, 0.54736328125, 0.557098765432099, 0.555694203317901, 0.540364583333333, 0.508324893904322, 0.456790123456789, 0.382975260416667, 0.284095293209876, 0.157365210262345, 0], 37 | "FAKE_CUBIC_HERMITE": [0, 0, 0.157576195987654, 0.28491512345679, 0.384765625, 0.459876543209877, 0.512996720679012, 0.546875, 0.564260223765432, 0.567901234567901, 0.560546875, 0.544945987654321, 0.523847415123457, 0.5, 0.476152584876543, 0.455054012345679, 0.439453125, 0.432098765432099, 0.435739776234568, 0.453125, 0.487003279320987, 0.540123456790124, 
0.615234375, 0.71508487654321, 0.842423804012347, 1], 38 | "FAKE_REVERSE_CUBIC_HERMITE": [0, 1, 0.842423804012346, 0.71508487654321, 0.615234375, 0.540123456790123, 0.487003279320988, 0.453125, 0.435739776234568, 0.432098765432099, 0.439453125, 0.455054012345679, 0.476152584876543, 0.5, 0.523847415123457, 0.544945987654321, 0.560546875, 0.567901234567901, 0.564260223765432, 0.546875, 0.512996720679013, 0.459876543209876, 0.384765625, 0.28491512345679, 0.157576195987653, 0], 39 | "ALL_A": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 40 | "ALL_B": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 41 | "LOW_OFFSET_CUBIC_HERMITE": [0, 0, 0.099515938464506, 0.1628809799382715, 0.2123209635416665, 0.249228395061729, 0.274995780285494, 0.291015625, 0.298680434992284, 0.2993827160493825, 0.294514973958333, 0.285469714506173, 0.273639443479938, 0.261513611593364, 0.24938777970679, 0.245727237654321, 0.23763671875, 0.222901234567901, 0.224305796682099, 0.234635416666667, 0.247675106095678, 0.273209876543211, 0.312024739583333, 0.360904706790124, 0.422634789737655, 0.5], 42 | } 43 | # fmt: on 44 | -------------------------------------------------------------------------------- /sd_meh/rebasin.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion 2 | import logging 3 | from collections import defaultdict 4 | from random import shuffle 5 | from typing import Dict, NamedTuple, Tuple 6 | 7 | import torch 8 | from scipy.optimize import linear_sum_assignment 9 | 10 | logging.getLogger("sd_meh").addHandler(logging.NullHandler()) 11 | SPECIAL_KEYS = [ 12 | "first_stage_model.decoder.norm_out.weight", 13 | "first_stage_model.decoder.norm_out.bias", 14 | "first_stage_model.encoder.norm_out.weight", 15 | "first_stage_model.encoder.norm_out.bias", 16 | "model.diffusion_model.out.0.weight", 17 | 
def step_weights_and_bases(
    weights: Dict, bases: Dict, it: int = 0, iterations: int = 1
) -> Tuple[Dict, Dict]:
    """Rescale per-block weights and bases for rebasin iteration ``it``.

    On the first iteration each value is a plain ``1/iterations`` fraction;
    later iterations compensate for the fraction already applied by earlier
    passes so the total after all iterations matches the requested value.
    """

    def stepped(v):
        if it > 0:
            return 1 - (1 - (1 + it) * v / iterations) / (1 - it * v / iterations)
        return v / iterations

    stepped_weights = {
        greek_letter: [stepped(v) for v in w] for greek_letter, w in weights.items()
    }
    stepped_bases = {k: stepped(v) for k, v in bases.items()}
    return stepped_weights, stepped_bases


def flatten_params(model):
    """Return the raw state dict of a loaded checkpoint."""
    return model["state_dict"]


# NOTE(review): leftover from the JAX git-rebasin port — `random` here would
# have to be jax.random, but this module only does `from random import
# shuffle`, so calling this raises NameError. Appears unused; confirm before
# removing.
rngmix = lambda rng, x: random.fold_in(rng, hash(x))


class PermutationSpec(NamedTuple):
    # perm name -> list of (weight key, axis) pairs the permutation acts on
    perm_to_axes: dict
    # weight key -> per-axis tuple of perm names (None = axis not permuted)
    axes_to_perm: dict


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build the inverse (perm -> axes) view of an axes -> perm mapping and
    bundle both into a PermutationSpec."""
    perm_to_axes = defaultdict(list)
    for weight_key, axis_perms in axes_to_perm.items():
        for axis, perm_name in enumerate(axis_perms):
            if perm_name is not None:
                perm_to_axes[perm_name].append((weight_key, axis))
    return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)
**dense( 96 | f"{name}.emb_layers.1", f"P_{name}_inner2", f"P_{name}_inner3", bias=True 97 | ), 98 | **norm(f"{name}.out_layers.0", f"P_{name}_inner4"), 99 | **conv(f"{name}.out_layers.3", f"P_{name}_inner4", p_out), 100 | } 101 | 102 | # Text Encoder blocks 103 | easyblock2 = lambda name, p: { 104 | **norm(f"{name}.norm1", p), 105 | **conv(f"{name}.conv1", p, f"P_{name}_inner"), 106 | **norm(f"{name}.norm2", f"P_{name}_inner"), 107 | **conv(f"{name}.conv2", f"P_{name}_inner", p), 108 | } 109 | 110 | # This is for blocks that use a residual connection, but change the number of channels via a Conv. 111 | shortcutblock = lambda name, p_in, p_out: { 112 | **norm(f"{name}.norm1", p_in), 113 | **conv(f"{name}.conv1", p_in, f"P_{name}_inner"), 114 | **norm(f"{name}.norm2", f"P_{name}_inner"), 115 | **conv(f"{name}.conv2", f"P_{name}_inner", p_out), 116 | **conv(f"{name}.nin_shortcut", p_in, p_out), 117 | **norm(f"{name}.nin_shortcut", p_out), 118 | } 119 | 120 | return permutation_spec_from_axes_to_perm( 121 | { 122 | # Skipped Layers 123 | **skip("betas", None, None), 124 | **skip("alphas_cumprod", None, None), 125 | **skip("alphas_cumprod_prev", None, None), 126 | **skip("sqrt_alphas_cumprod", None, None), 127 | **skip("sqrt_one_minus_alphas_cumprod", None, None), 128 | **skip("log_one_minus_alphas_cumprods", None, None), 129 | **skip("sqrt_recip_alphas_cumprod", None, None), 130 | **skip("sqrt_recipm1_alphas_cumprod", None, None), 131 | **skip("posterior_variance", None, None), 132 | **skip("posterior_log_variance_clipped", None, None), 133 | **skip("posterior_mean_coef1", None, None), 134 | **skip("posterior_mean_coef2", None, None), 135 | **skip("log_one_minus_alphas_cumprod", None, None), 136 | **skip("model_ema.decay", None, None), 137 | **skip("model_ema.num_updates", None, None), 138 | # initial 139 | **dense("model.diffusion_model.time_embed.0", None, "P_bg0", bias=True), 140 | **dense("model.diffusion_model.time_embed.2", "P_bg0", "P_bg1", bias=True), 141 | 
**conv("model.diffusion_model.input_blocks.0.0", "P_bg2", "P_bg3"), 142 | # input blocks 143 | **easyblock("model.diffusion_model.input_blocks.1.0", "P_bg4", "P_bg5"), 144 | **norm("model.diffusion_model.input_blocks.1.1.norm", "P_bg6"), 145 | **conv("model.diffusion_model.input_blocks.1.1.proj_in", "P_bg6", "P_bg7"), 146 | **dense( 147 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q", 148 | "P_bg8", 149 | "P_bg9", 150 | bias=False, 151 | ), 152 | **dense( 153 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k", 154 | "P_bg8", 155 | "P_bg9", 156 | bias=False, 157 | ), 158 | **dense( 159 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v", 160 | "P_bg8", 161 | "P_bg9", 162 | bias=False, 163 | ), 164 | **dense( 165 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0", 166 | "P_bg8", 167 | "P_bg9", 168 | bias=True, 169 | ), 170 | **dense( 171 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj", 172 | "P_bg10", 173 | "P_bg11", 174 | bias=True, 175 | ), 176 | **dense( 177 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2", 178 | "P_bg12", 179 | "P_bg13", 180 | bias=True, 181 | ), 182 | **dense( 183 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q", 184 | "P_bg14", 185 | "P_bg15", 186 | bias=False, 187 | ), 188 | **dense( 189 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k", 190 | "P_bg16", 191 | "P_bg17", 192 | bias=False, 193 | ), 194 | **dense( 195 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v", 196 | "P_bg16", 197 | "P_bg17", 198 | bias=False, 199 | ), 200 | **dense( 201 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0", 202 | "P_bg18", 203 | "P_bg19", 204 | bias=True, 205 | ), 206 | **norm( 207 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1", 208 | "P_bg19", 209 | ), 210 | **norm( 211 | 
"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2", 212 | "P_bg19", 213 | ), 214 | **norm( 215 | "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3", 216 | "P_bg19", 217 | ), 218 | **conv( 219 | "model.diffusion_model.input_blocks.1.1.proj_out", "P_bg19", "P_bg20" 220 | ), 221 | **easyblock("model.diffusion_model.input_blocks.2.0", "P_bg21", "P_bg22"), 222 | **norm("model.diffusion_model.input_blocks.2.1.norm", "P_bg23"), 223 | **conv( 224 | "model.diffusion_model.input_blocks.2.1.proj_in", "P_bg23", "P_bg24" 225 | ), 226 | **dense( 227 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q", 228 | "P_bg25", 229 | "P_bg26", 230 | bias=False, 231 | ), 232 | **dense( 233 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k", 234 | "P_bg25", 235 | "P_bg26", 236 | bias=False, 237 | ), 238 | **dense( 239 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v", 240 | "P_bg25", 241 | "P_bg26", 242 | bias=False, 243 | ), 244 | **dense( 245 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0", 246 | "P_bg25", 247 | "P_bg26", 248 | bias=True, 249 | ), 250 | **dense( 251 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj", 252 | "P_bg27", 253 | "P_bg28", 254 | bias=True, 255 | ), 256 | **dense( 257 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2", 258 | "P_bg29", 259 | "P_bg30", 260 | bias=True, 261 | ), 262 | **dense( 263 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q", 264 | "P_bg31", 265 | "P_bg32", 266 | bias=False, 267 | ), 268 | **dense( 269 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k", 270 | "P_bg33", 271 | "P_bg34", 272 | bias=False, 273 | ), 274 | **dense( 275 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v", 276 | "P_bg33", 277 | "P_bg34", 278 | bias=False, 279 | ), 280 | **dense( 281 | 
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0", 282 | "P_bg35", 283 | "P_bg36", 284 | bias=True, 285 | ), 286 | **norm( 287 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1", 288 | "P_bg36", 289 | ), 290 | **norm( 291 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2", 292 | "P_bg36", 293 | ), 294 | **norm( 295 | "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3", 296 | "P_bg36", 297 | ), 298 | **conv( 299 | "model.diffusion_model.input_blocks.2.1.proj_out", "P_bg36", "P_bg37" 300 | ), 301 | **conv("model.diffusion_model.input_blocks.3.0.op", "P_bg38", "P_bg39"), 302 | **easyblock("model.diffusion_model.input_blocks.4.0", "P_bg40", "P_bg41"), 303 | **conv( 304 | "model.diffusion_model.input_blocks.4.0.skip_connection", 305 | "P_bg42", 306 | "P_bg43", 307 | ), 308 | **norm("model.diffusion_model.input_blocks.4.1.norm", "P_bg44"), 309 | **conv( 310 | "model.diffusion_model.input_blocks.4.1.proj_in", "P_bg44", "P_bg45" 311 | ), 312 | **dense( 313 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q", 314 | "P_bg46", 315 | "P_bg47", 316 | bias=False, 317 | ), 318 | **dense( 319 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k", 320 | "P_bg46", 321 | "P_bg47", 322 | bias=False, 323 | ), 324 | **dense( 325 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v", 326 | "P_bg46", 327 | "P_bg47", 328 | bias=False, 329 | ), 330 | **dense( 331 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0", 332 | "P_bg46", 333 | "P_bg47", 334 | bias=True, 335 | ), 336 | **dense( 337 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj", 338 | "P_bg48", 339 | "P_bg49", 340 | bias=True, 341 | ), 342 | **dense( 343 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2", 344 | "P_bg50", 345 | "P_bg51", 346 | bias=True, 347 | ), 348 | **dense( 349 | 
"model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q", 350 | "P_bg52", 351 | "P_bg53", 352 | bias=False, 353 | ), 354 | **dense( 355 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k", 356 | "P_bg54", 357 | "P_bg55", 358 | bias=False, 359 | ), 360 | **dense( 361 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v", 362 | "P_bg54", 363 | "P_bg55", 364 | bias=False, 365 | ), 366 | **dense( 367 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0", 368 | "P_bg56", 369 | "P_bg57", 370 | bias=True, 371 | ), 372 | **norm( 373 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1", 374 | "P_bg57", 375 | ), 376 | **norm( 377 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2", 378 | "P_bg57", 379 | ), 380 | **norm( 381 | "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3", 382 | "P_bg57", 383 | ), 384 | **conv( 385 | "model.diffusion_model.input_blocks.4.1.proj_out", "P_bg57", "P_bg58" 386 | ), 387 | **easyblock("model.diffusion_model.input_blocks.5.0", "P_bg59", "P_bg60"), 388 | **norm("model.diffusion_model.input_blocks.5.1.norm", "P_bg61"), 389 | **conv( 390 | "model.diffusion_model.input_blocks.5.1.proj_in", "P_bg61", "P_bg62" 391 | ), 392 | **dense( 393 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q", 394 | "P_bg63", 395 | "P_bg64", 396 | bias=False, 397 | ), 398 | **dense( 399 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k", 400 | "P_bg63", 401 | "P_bg64", 402 | bias=False, 403 | ), 404 | **dense( 405 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v", 406 | "P_bg63", 407 | "P_bg64", 408 | bias=False, 409 | ), 410 | **dense( 411 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0", 412 | "P_bg63", 413 | "P_bg64", 414 | bias=True, 415 | ), 416 | **dense( 417 | 
"model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj", 418 | "P_bg65", 419 | "P_bg66", 420 | bias=True, 421 | ), 422 | **dense( 423 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2", 424 | "P_bg67", 425 | "P_bg68", 426 | bias=True, 427 | ), 428 | **dense( 429 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q", 430 | "P_bg69", 431 | "P_bg70", 432 | bias=False, 433 | ), 434 | **dense( 435 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k", 436 | "P_bg71", 437 | "P_bg72", 438 | bias=False, 439 | ), 440 | **dense( 441 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v", 442 | "P_bg71", 443 | "P_bg72", 444 | bias=False, 445 | ), 446 | **dense( 447 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0", 448 | "P_bg73", 449 | "P_bg74", 450 | bias=True, 451 | ), 452 | **norm( 453 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1", 454 | "P_bg74", 455 | ), 456 | **norm( 457 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2", 458 | "P_bg74", 459 | ), 460 | **norm( 461 | "model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3", 462 | "P_bg74", 463 | ), 464 | **conv( 465 | "model.diffusion_model.input_blocks.5.1.proj_out", "P_bg74", "P_bg75" 466 | ), 467 | **conv("model.diffusion_model.input_blocks.6.0.op", "P_bg76", "P_bg77"), 468 | **easyblock("model.diffusion_model.input_blocks.7.0", "P_bg78", "P_bg79"), 469 | **conv( 470 | "model.diffusion_model.input_blocks.7.0.skip_connection", 471 | "P_bg80", 472 | "P_bg81", 473 | ), 474 | **norm("model.diffusion_model.input_blocks.7.1.norm", "P_bg82"), 475 | **conv( 476 | "model.diffusion_model.input_blocks.7.1.proj_in", "P_bg82", "P_bg83" 477 | ), 478 | **dense( 479 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q", 480 | "P_bg84", 481 | "P_bg85", 482 | bias=False, 483 | ), 484 | **dense( 485 | 
"model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k", 486 | "P_bg84", 487 | "P_bg85", 488 | bias=False, 489 | ), 490 | **dense( 491 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v", 492 | "P_bg84", 493 | "P_bg85", 494 | bias=False, 495 | ), 496 | **dense( 497 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0", 498 | "P_bg84", 499 | "P_bg85", 500 | bias=True, 501 | ), 502 | **dense( 503 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj", 504 | "P_bg86", 505 | "P_bg87", 506 | bias=True, 507 | ), 508 | **dense( 509 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2", 510 | "P_bg88", 511 | "P_bg89", 512 | bias=True, 513 | ), 514 | **dense( 515 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q", 516 | "P_bg90", 517 | "P_bg91", 518 | bias=False, 519 | ), 520 | **dense( 521 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k", 522 | "P_bg92", 523 | "P_bg93", 524 | bias=False, 525 | ), 526 | **dense( 527 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v", 528 | "P_bg92", 529 | "P_bg93", 530 | bias=False, 531 | ), 532 | **dense( 533 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0", 534 | "P_bg94", 535 | "P_bg95", 536 | bias=True, 537 | ), 538 | **norm( 539 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1", 540 | "P_bg95", 541 | ), 542 | **norm( 543 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2", 544 | "P_bg95", 545 | ), 546 | **norm( 547 | "model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3", 548 | "P_bg95", 549 | ), 550 | **conv( 551 | "model.diffusion_model.input_blocks.7.1.proj_out", "P_bg95", "P_bg96" 552 | ), 553 | **easyblock("model.diffusion_model.input_blocks.8.0", "P_bg97", "P_bg98"), 554 | **norm("model.diffusion_model.input_blocks.8.1.norm", "P_bg99"), 555 | **conv( 556 | 
"model.diffusion_model.input_blocks.8.1.proj_in", "P_bg99", "P_bg100" 557 | ), 558 | **dense( 559 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q", 560 | "P_bg101", 561 | "P_bg102", 562 | bias=False, 563 | ), 564 | **dense( 565 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k", 566 | "P_bg101", 567 | "P_bg102", 568 | bias=False, 569 | ), 570 | **dense( 571 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v", 572 | "P_bg101", 573 | "P_bg102", 574 | bias=False, 575 | ), 576 | **dense( 577 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0", 578 | "P_bg101", 579 | "P_bg102", 580 | bias=True, 581 | ), 582 | **dense( 583 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj", 584 | "P_bg103", 585 | "P_bg104", 586 | bias=True, 587 | ), 588 | **dense( 589 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2", 590 | "P_bg105", 591 | "P_bg106", 592 | bias=True, 593 | ), 594 | **dense( 595 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q", 596 | "P_bg107", 597 | "P_bg108", 598 | bias=False, 599 | ), 600 | **dense( 601 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k", 602 | "P_bg109", 603 | "P_bg110", 604 | bias=False, 605 | ), 606 | **dense( 607 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v", 608 | "P_bg109", 609 | "P_bg110", 610 | bias=False, 611 | ), 612 | **dense( 613 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0", 614 | "P_bg111", 615 | "P_bg112", 616 | bias=True, 617 | ), 618 | **norm( 619 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1", 620 | "P_bg112", 621 | ), 622 | **norm( 623 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2", 624 | "P_bg112", 625 | ), 626 | **norm( 627 | "model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3", 628 | "P_bg112", 629 | ), 630 | 
**conv( 631 | "model.diffusion_model.input_blocks.8.1.proj_out", "P_bg112", "P_bg113" 632 | ), 633 | **conv("model.diffusion_model.input_blocks.9.0.op", "P_bg114", "P_bg115"), 634 | **easyblock( 635 | "model.diffusion_model.input_blocks.10.0", "P_bg115", "P_bg116" 636 | ), 637 | **easyblock( 638 | "model.diffusion_model.input_blocks.11.0", "P_bg116", "P_bg117" 639 | ), 640 | # middle blocks 641 | **easyblock("model.diffusion_model.middle_block.0", "P_bg117", "P_bg118"), 642 | **norm("model.diffusion_model.middle_block.1.norm", "P_bg119"), 643 | **conv( 644 | "model.diffusion_model.middle_block.1.proj_in", "P_bg119", "P_bg120" 645 | ), 646 | **dense( 647 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q", 648 | "P_bg121", 649 | "P_bg122", 650 | bias=False, 651 | ), 652 | **dense( 653 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k", 654 | "P_bg121", 655 | "P_bg122", 656 | bias=False, 657 | ), 658 | **dense( 659 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v", 660 | "P_bg121", 661 | "P_bg122", 662 | bias=False, 663 | ), 664 | **dense( 665 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0", 666 | "P_bg121", 667 | "P_bg122", 668 | bias=True, 669 | ), 670 | **dense( 671 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj", 672 | "P_bg123", 673 | "P_bg124", 674 | bias=True, 675 | ), 676 | **dense( 677 | "model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2", 678 | "P_bg125", 679 | "P_bg126", 680 | bias=True, 681 | ), 682 | **dense( 683 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q", 684 | "P_bg127", 685 | "P_bg128", 686 | bias=False, 687 | ), 688 | **dense( 689 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k", 690 | "P_bg129", 691 | "P_bg130", 692 | bias=False, 693 | ), 694 | **dense( 695 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v", 696 | "P_bg129", 697 | 
"P_bg130", 698 | bias=False, 699 | ), 700 | **dense( 701 | "model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0", 702 | "P_bg131", 703 | "P_bg132", 704 | bias=True, 705 | ), 706 | **norm( 707 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm1", 708 | "P_bg132", 709 | ), 710 | **norm( 711 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm2", 712 | "P_bg132", 713 | ), 714 | **norm( 715 | "model.diffusion_model.middle_block.1.transformer_blocks.0.norm3", 716 | "P_bg132", 717 | ), 718 | **conv( 719 | "model.diffusion_model.middle_block.1.proj_out", "P_bg132", "P_bg133" 720 | ), 721 | **easyblock("model.diffusion_model.middle_block.2", "P_bg134", "P_bg135"), 722 | # output blocks 723 | **easyblock( 724 | "model.diffusion_model.output_blocks.0.0", "P_bg136", "P_bg137" 725 | ), 726 | **conv( 727 | "model.diffusion_model.output_blocks.0.0.skip_connection", 728 | "P_bg138", 729 | "P_bg139", 730 | ), 731 | **easyblock( 732 | "model.diffusion_model.output_blocks.1.0", "P_bg140", "P_bg141" 733 | ), 734 | **conv( 735 | "model.diffusion_model.output_blocks.1.0.skip_connection", 736 | "P_bg142", 737 | "P_bg143", 738 | ), 739 | **easyblock( 740 | "model.diffusion_model.output_blocks.2.0", "P_bg144", "P_bg145" 741 | ), 742 | **conv( 743 | "model.diffusion_model.output_blocks.2.0.skip_connection", 744 | "P_bg146", 745 | "P_bg147", 746 | ), 747 | **conv( 748 | "model.diffusion_model.output_blocks.2.1.conv", "P_bg148", "P_bg149" 749 | ), 750 | **easyblock( 751 | "model.diffusion_model.output_blocks.3.0", "P_bg150", "P_bg151" 752 | ), 753 | **conv( 754 | "model.diffusion_model.output_blocks.3.0.skip_connection", 755 | "P_bg152", 756 | "P_bg153", 757 | ), 758 | **norm("model.diffusion_model.output_blocks.3.1.norm", "P_bg154"), 759 | **conv( 760 | "model.diffusion_model.output_blocks.3.1.proj_in", "P_bg154", "P_bg155" 761 | ), 762 | **dense( 763 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q", 764 | 
"P_bg156", 765 | "P_bg157", 766 | bias=False, 767 | ), 768 | **dense( 769 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k", 770 | "P_bg156", 771 | "P_bg157", 772 | bias=False, 773 | ), 774 | **dense( 775 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v", 776 | "P_bg156", 777 | "P_bg157", 778 | bias=False, 779 | ), 780 | **dense( 781 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0", 782 | "P_bg156", 783 | "P_bg157", 784 | bias=True, 785 | ), 786 | **dense( 787 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj", 788 | "P_bg158", 789 | "P_bg159", 790 | bias=True, 791 | ), 792 | **dense( 793 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2", 794 | "P_bg160", 795 | "P_bg161", 796 | bias=True, 797 | ), 798 | **dense( 799 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q", 800 | "P_bg162", 801 | "P_bg163", 802 | bias=False, 803 | ), 804 | **dense( 805 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k", 806 | "P_bg164", 807 | "P_bg165", 808 | bias=False, 809 | ), 810 | **dense( 811 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v", 812 | "P_bg164", 813 | "P_bg165", 814 | bias=False, 815 | ), 816 | **dense( 817 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0", 818 | "P_bg166", 819 | "P_bg167", 820 | bias=True, 821 | ), 822 | **norm( 823 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1", 824 | "P_bg167", 825 | ), 826 | **norm( 827 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2", 828 | "P_bg167", 829 | ), 830 | **norm( 831 | "model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3", 832 | "P_bg167", 833 | ), 834 | **conv( 835 | "model.diffusion_model.output_blocks.3.1.proj_out", "P_bg167", "P_bg168" 836 | ), 837 | **easyblock( 838 | "model.diffusion_model.output_blocks.4.0", 
"P_bg169", "P_bg170" 839 | ), 840 | **conv( 841 | "model.diffusion_model.output_blocks.4.0.skip_connection", 842 | "P_bg171", 843 | "P_bg172", 844 | ), 845 | **norm("model.diffusion_model.output_blocks.4.1.norm", "P_bg173"), 846 | **conv( 847 | "model.diffusion_model.output_blocks.4.1.proj_in", "P_bg173", "P_bg174" 848 | ), 849 | **dense( 850 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q", 851 | "P_bg175", 852 | "P_bg176", 853 | bias=False, 854 | ), 855 | **dense( 856 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k", 857 | "P_bg175", 858 | "P_bg176", 859 | bias=False, 860 | ), 861 | **dense( 862 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v", 863 | "P_bg175", 864 | "P_bg176", 865 | bias=False, 866 | ), 867 | **dense( 868 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0", 869 | "P_bg175", 870 | "P_bg176", 871 | bias=True, 872 | ), 873 | **dense( 874 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj", 875 | "P_bg177", 876 | "P_bg178", 877 | bias=True, 878 | ), 879 | **dense( 880 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2", 881 | "P_bg179", 882 | "P_bg180", 883 | bias=True, 884 | ), 885 | **dense( 886 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q", 887 | "P_bg181", 888 | "P_bg182", 889 | bias=False, 890 | ), 891 | **dense( 892 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k", 893 | "P_bg183", 894 | "P_bg184", 895 | bias=False, 896 | ), 897 | **dense( 898 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v", 899 | "P_bg183", 900 | "P_bg184", 901 | bias=False, 902 | ), 903 | **dense( 904 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0", 905 | "P_bg185", 906 | "P_bg186", 907 | bias=True, 908 | ), 909 | **norm( 910 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1", 911 | 
"P_bg186", 912 | ), 913 | **norm( 914 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2", 915 | "P_bg186", 916 | ), 917 | **norm( 918 | "model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3", 919 | "P_bg186", 920 | ), 921 | **conv( 922 | "model.diffusion_model.output_blocks.4.1.proj_out", "P_bg186", "P_bg187" 923 | ), 924 | **easyblock( 925 | "model.diffusion_model.output_blocks.5.0", "P_bg188", "P_bg189" 926 | ), 927 | **conv( 928 | "model.diffusion_model.output_blocks.5.0.skip_connection", 929 | "P_bg190", 930 | "P_bg191", 931 | ), 932 | **norm("model.diffusion_model.output_blocks.5.1.norm", "P_bg192"), 933 | **conv( 934 | "model.diffusion_model.output_blocks.5.1.proj_in", "P_bg192", "P_bg193" 935 | ), 936 | **dense( 937 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q", 938 | "P_bg194", 939 | "P_bg195", 940 | bias=False, 941 | ), 942 | **dense( 943 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k", 944 | "P_bg194", 945 | "P_bg195", 946 | bias=False, 947 | ), 948 | **dense( 949 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v", 950 | "P_bg194", 951 | "P_bg195", 952 | bias=False, 953 | ), 954 | **dense( 955 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0", 956 | "P_bg194", 957 | "P_bg195", 958 | bias=True, 959 | ), 960 | **dense( 961 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj", 962 | "P_bg196", 963 | "P_bg197", 964 | bias=True, 965 | ), 966 | **dense( 967 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2", 968 | "P_bg198", 969 | "P_bg199", 970 | bias=True, 971 | ), 972 | **dense( 973 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q", 974 | "P_bg200", 975 | "P_bg201", 976 | bias=False, 977 | ), 978 | **dense( 979 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k", 980 | "P_bg202", 981 | "P_bg203", 982 | bias=False, 
983 | ), 984 | **dense( 985 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v", 986 | "P_bg202", 987 | "P_bg203", 988 | bias=False, 989 | ), 990 | **dense( 991 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0", 992 | "P_bg204", 993 | "P_bg205", 994 | bias=True, 995 | ), 996 | **norm( 997 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1", 998 | "P_bg205", 999 | ), 1000 | **norm( 1001 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2", 1002 | "P_bg205", 1003 | ), 1004 | **norm( 1005 | "model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3", 1006 | "P_bg205", 1007 | ), 1008 | **conv( 1009 | "model.diffusion_model.output_blocks.5.1.proj_out", "P_bg205", "P_bg206" 1010 | ), 1011 | **conv( 1012 | "model.diffusion_model.output_blocks.5.2.conv", "P_bg206", "P_bg207" 1013 | ), 1014 | **easyblock( 1015 | "model.diffusion_model.output_blocks.6.0", "P_bg208", "P_bg209" 1016 | ), 1017 | **conv( 1018 | "model.diffusion_model.output_blocks.6.0.skip_connection", 1019 | "P_bg210", 1020 | "P_bg211", 1021 | ), 1022 | **norm("model.diffusion_model.output_blocks.6.1.norm", "P_bg212"), 1023 | **conv( 1024 | "model.diffusion_model.output_blocks.6.1.proj_in", "P_bg212", "P_bg213" 1025 | ), 1026 | **dense( 1027 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q", 1028 | "P_bg214", 1029 | "P_bg215", 1030 | bias=False, 1031 | ), 1032 | **dense( 1033 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k", 1034 | "P_bg214", 1035 | "P_bg215", 1036 | bias=False, 1037 | ), 1038 | **dense( 1039 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v", 1040 | "P_bg214", 1041 | "P_bg215", 1042 | bias=False, 1043 | ), 1044 | **dense( 1045 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0", 1046 | "P_bg214", 1047 | "P_bg215", 1048 | bias=True, 1049 | ), 1050 | **dense( 1051 | 
"model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj", 1052 | "P_bg216", 1053 | "P_bg217", 1054 | bias=True, 1055 | ), 1056 | **dense( 1057 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2", 1058 | "P_bg218", 1059 | "P_bg219", 1060 | bias=True, 1061 | ), 1062 | **dense( 1063 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q", 1064 | "P_bg220", 1065 | "P_bg221", 1066 | bias=False, 1067 | ), 1068 | **dense( 1069 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k", 1070 | "P_bg222", 1071 | "P_bg223", 1072 | bias=False, 1073 | ), 1074 | **dense( 1075 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v", 1076 | "P_bg222", 1077 | "P_bg223", 1078 | bias=False, 1079 | ), 1080 | **dense( 1081 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0", 1082 | "P_bg224", 1083 | "P_bg225", 1084 | bias=True, 1085 | ), 1086 | **norm( 1087 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1", 1088 | "P_bg225", 1089 | ), 1090 | **norm( 1091 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2", 1092 | "P_bg225", 1093 | ), 1094 | **norm( 1095 | "model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3", 1096 | "P_bg225", 1097 | ), 1098 | **conv( 1099 | "model.diffusion_model.output_blocks.6.1.proj_out", "P_bg225", "P_bg226" 1100 | ), 1101 | **easyblock( 1102 | "model.diffusion_model.output_blocks.7.0", "P_bg227", "P_bg228" 1103 | ), 1104 | **conv( 1105 | "model.diffusion_model.output_blocks.7.0.skip_connection", 1106 | "P_bg229", 1107 | "P_bg230", 1108 | ), 1109 | **norm("model.diffusion_model.output_blocks.7.1.norm", "P_bg231"), 1110 | **conv( 1111 | "model.diffusion_model.output_blocks.7.1.proj_in", "P_bg231", "P_bg232" 1112 | ), 1113 | **dense( 1114 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q", 1115 | "P_bg233", 1116 | "P_bg234", 1117 | bias=False, 1118 | ), 1119 | 
**dense( 1120 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k", 1121 | "P_bg233", 1122 | "P_bg234", 1123 | bias=False, 1124 | ), 1125 | **dense( 1126 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v", 1127 | "P_bg233", 1128 | "P_bg234", 1129 | bias=False, 1130 | ), 1131 | **dense( 1132 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0", 1133 | "P_bg233", 1134 | "P_bg234", 1135 | bias=True, 1136 | ), 1137 | **dense( 1138 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj", 1139 | "P_bg235", 1140 | "P_bg236", 1141 | bias=True, 1142 | ), 1143 | **dense( 1144 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2", 1145 | "P_bg237", 1146 | "P_bg238", 1147 | bias=True, 1148 | ), 1149 | **dense( 1150 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q", 1151 | "P_bg239", 1152 | "P_bg240", 1153 | bias=False, 1154 | ), 1155 | **dense( 1156 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k", 1157 | "P_bg241", 1158 | "P_bg242", 1159 | bias=False, 1160 | ), 1161 | **dense( 1162 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v", 1163 | "P_bg241", 1164 | "P_bg242", 1165 | bias=False, 1166 | ), 1167 | **dense( 1168 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0", 1169 | "P_bg243", 1170 | "P_bg244", 1171 | bias=True, 1172 | ), 1173 | **norm( 1174 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1", 1175 | "P_bg244", 1176 | ), 1177 | **norm( 1178 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2", 1179 | "P_bg244", 1180 | ), 1181 | **norm( 1182 | "model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3", 1183 | "P_bg244", 1184 | ), 1185 | **conv( 1186 | "model.diffusion_model.output_blocks.7.1.proj_out", "P_bg244", "P_bg245" 1187 | ), 1188 | **easyblock( 1189 | "model.diffusion_model.output_blocks.8.0", 
"P_bg246", "P_bg247" 1190 | ), 1191 | **conv( 1192 | "model.diffusion_model.output_blocks.8.0.skip_connection", 1193 | "P_bg248", 1194 | "P_bg249", 1195 | ), 1196 | **norm("model.diffusion_model.output_blocks.8.1.norm", "P_bg250"), 1197 | **conv( 1198 | "model.diffusion_model.output_blocks.8.1.proj_in", "P_bg250", "P_bg251" 1199 | ), 1200 | **dense( 1201 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q", 1202 | "P_bg252", 1203 | "P_bg253", 1204 | bias=False, 1205 | ), 1206 | **dense( 1207 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k", 1208 | "P_bg252", 1209 | "P_bg253", 1210 | bias=False, 1211 | ), 1212 | **dense( 1213 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v", 1214 | "P_bg252", 1215 | "P_bg253", 1216 | bias=False, 1217 | ), 1218 | **dense( 1219 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0", 1220 | "P_bg252", 1221 | "P_bg253", 1222 | bias=True, 1223 | ), 1224 | **dense( 1225 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj", 1226 | "P_bg254", 1227 | "P_bg255", 1228 | bias=True, 1229 | ), 1230 | **dense( 1231 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2", 1232 | "P_bg256", 1233 | "P_bg257", 1234 | bias=True, 1235 | ), 1236 | **dense( 1237 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q", 1238 | "P_bg258", 1239 | "P_bg259", 1240 | bias=False, 1241 | ), 1242 | **dense( 1243 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k", 1244 | "P_bg260", 1245 | "P_bg261", 1246 | bias=False, 1247 | ), 1248 | **dense( 1249 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v", 1250 | "P_bg260", 1251 | "P_bg261", 1252 | bias=False, 1253 | ), 1254 | **dense( 1255 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0", 1256 | "P_bg262", 1257 | "P_bg263", 1258 | bias=True, 1259 | ), 1260 | **norm( 1261 | 
"model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1", 1262 | "P_bg263", 1263 | ), 1264 | **norm( 1265 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2", 1266 | "P_bg263", 1267 | ), 1268 | **norm( 1269 | "model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3", 1270 | "P_bg263", 1271 | ), 1272 | **conv( 1273 | "model.diffusion_model.output_blocks.8.1.proj_out", "P_bg263", "P_bg264" 1274 | ), 1275 | **conv( 1276 | "model.diffusion_model.output_blocks.8.2.conv", "P_bg265", "P_bg266" 1277 | ), 1278 | **easyblock( 1279 | "model.diffusion_model.output_blocks.9.0", "P_bg267", "P_bg268" 1280 | ), 1281 | **conv( 1282 | "model.diffusion_model.output_blocks.9.0.skip_connection", 1283 | "P_bg269", 1284 | "P_bg270", 1285 | ), 1286 | **norm("model.diffusion_model.output_blocks.9.1.norm", "P_bg271"), 1287 | **conv( 1288 | "model.diffusion_model.output_blocks.9.1.proj_in", "P_bg271", "P_bg272" 1289 | ), 1290 | **dense( 1291 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q", 1292 | "P_bg273", 1293 | "P_bg274", 1294 | bias=False, 1295 | ), 1296 | **dense( 1297 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k", 1298 | "P_bg273", 1299 | "P_bg274", 1300 | bias=False, 1301 | ), 1302 | **dense( 1303 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v", 1304 | "P_bg273", 1305 | "P_bg274", 1306 | bias=False, 1307 | ), 1308 | **dense( 1309 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0", 1310 | "P_bg273", 1311 | "P_bg274", 1312 | bias=True, 1313 | ), 1314 | **dense( 1315 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj", 1316 | "P_bg275", 1317 | "P_bg276", 1318 | bias=True, 1319 | ), 1320 | **dense( 1321 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2", 1322 | "P_bg277", 1323 | "P_bg278", 1324 | bias=True, 1325 | ), 1326 | **dense( 1327 | 
"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q", 1328 | "P_bg279", 1329 | "P_bg280", 1330 | bias=False, 1331 | ), 1332 | **dense( 1333 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k", 1334 | "P_bg281", 1335 | "P_bg282", 1336 | bias=False, 1337 | ), 1338 | **dense( 1339 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v", 1340 | "P_bg281", 1341 | "P_bg282", 1342 | bias=False, 1343 | ), 1344 | **dense( 1345 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0", 1346 | "P_bg283", 1347 | "P_bg284", 1348 | bias=True, 1349 | ), 1350 | **norm( 1351 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1", 1352 | "P_bg284", 1353 | ), 1354 | **norm( 1355 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2", 1356 | "P_bg284", 1357 | ), 1358 | **norm( 1359 | "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3", 1360 | "P_bg284", 1361 | ), 1362 | **conv( 1363 | "model.diffusion_model.output_blocks.9.1.proj_out", "P_bg284", "P_bg285" 1364 | ), 1365 | **easyblock( 1366 | "model.diffusion_model.output_blocks.10.0", "P_bg286", "P_bg287" 1367 | ), 1368 | **conv( 1369 | "model.diffusion_model.output_blocks.10.0.skip_connection", 1370 | "P_bg288", 1371 | "P_bg289", 1372 | ), 1373 | **norm("model.diffusion_model.output_blocks.10.1.norm", "P_bg290"), 1374 | **conv( 1375 | "model.diffusion_model.output_blocks.10.1.proj_in", "P_bg290", "P_bg291" 1376 | ), 1377 | **dense( 1378 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q", 1379 | "P_bg292", 1380 | "P_bg293", 1381 | bias=False, 1382 | ), 1383 | **dense( 1384 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k", 1385 | "P_bg292", 1386 | "P_bg293", 1387 | bias=False, 1388 | ), 1389 | **dense( 1390 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v", 1391 | "P_bg292", 1392 | "P_bg293", 1393 | bias=False, 1394 | ), 
1395 | **dense( 1396 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0", 1397 | "P_bg292", 1398 | "P_bg293", 1399 | bias=True, 1400 | ), 1401 | **dense( 1402 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj", 1403 | "P_bg294", 1404 | "P_bg295", 1405 | bias=True, 1406 | ), 1407 | **dense( 1408 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2", 1409 | "P_bg296", 1410 | "P_bg297", 1411 | bias=True, 1412 | ), 1413 | **dense( 1414 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q", 1415 | "P_bg298", 1416 | "P_bg299", 1417 | bias=False, 1418 | ), 1419 | **dense( 1420 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k", 1421 | "P_bg300", 1422 | "P_bg301", 1423 | bias=False, 1424 | ), 1425 | **dense( 1426 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v", 1427 | "P_bg300", 1428 | "P_bg301", 1429 | bias=False, 1430 | ), 1431 | **dense( 1432 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0", 1433 | "P_bg302", 1434 | "P_bg303", 1435 | bias=True, 1436 | ), 1437 | **norm( 1438 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1", 1439 | "P_bg303", 1440 | ), 1441 | **norm( 1442 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2", 1443 | "P_bg303", 1444 | ), 1445 | **norm( 1446 | "model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3", 1447 | "P_bg303", 1448 | ), 1449 | **conv( 1450 | "model.diffusion_model.output_blocks.10.1.proj_out", 1451 | "P_bg303", 1452 | "P_bg304", 1453 | ), 1454 | **easyblock( 1455 | "model.diffusion_model.output_blocks.11.0", "P_bg305", "P_bg306" 1456 | ), 1457 | **conv( 1458 | "model.diffusion_model.output_blocks.11.0.skip_connection", 1459 | "P_bg307", 1460 | "P_bg308", 1461 | ), 1462 | **norm("model.diffusion_model.output_blocks.11.1.norm", "P_bg309"), 1463 | **conv( 1464 | 
"model.diffusion_model.output_blocks.11.1.proj_in", "P_bg309", "P_bg310" 1465 | ), 1466 | **dense( 1467 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q", 1468 | "P_bg311", 1469 | "P_bg312", 1470 | bias=False, 1471 | ), 1472 | **dense( 1473 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k", 1474 | "P_bg311", 1475 | "P_bg312", 1476 | bias=False, 1477 | ), 1478 | **dense( 1479 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v", 1480 | "P_bg311", 1481 | "P_bg312", 1482 | bias=False, 1483 | ), 1484 | **dense( 1485 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0", 1486 | "P_bg311", 1487 | "P_bg312", 1488 | bias=True, 1489 | ), 1490 | **dense( 1491 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj", 1492 | "P_bg313", 1493 | "P_bg314", 1494 | bias=True, 1495 | ), 1496 | **dense( 1497 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2", 1498 | "P_bg315", 1499 | "P_bg316", 1500 | bias=True, 1501 | ), 1502 | **dense( 1503 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q", 1504 | "P_bg317", 1505 | "P_bg318", 1506 | bias=False, 1507 | ), 1508 | **dense( 1509 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k", 1510 | "P_bg319", 1511 | "P_bg320", 1512 | bias=False, 1513 | ), 1514 | **dense( 1515 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v", 1516 | "P_bg319", 1517 | "P_bg320", 1518 | bias=False, 1519 | ), 1520 | **dense( 1521 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0", 1522 | "P_bg321", 1523 | "P_bg322", 1524 | bias=True, 1525 | ), 1526 | **norm( 1527 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1", 1528 | "P_bg322", 1529 | ), 1530 | **norm( 1531 | "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2", 1532 | "P_bg322", 1533 | ), 1534 | **norm( 1535 | 
"model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3", 1536 | "P_bg322", 1537 | ), 1538 | **conv( 1539 | "model.diffusion_model.output_blocks.11.1.proj_out", 1540 | "P_bg322", 1541 | "P_bg323", 1542 | ), 1543 | **norm("model.diffusion_model.out.0", "P_bg324"), 1544 | **conv("model.diffusion_model.out.2", "P_bg325", "P_bg326"), 1545 | # Text Encoder 1546 | # encoder down 1547 | **conv("first_stage_model.encoder.conv_in", "P_bg327", "P_bg328"), 1548 | **easyblock2("first_stage_model.encoder.down.0.block.0", "P_bg328"), 1549 | **easyblock2("first_stage_model.encoder.down.0.block.1", "P_bg328"), 1550 | **conv( 1551 | "first_stage_model.encoder.down.0.downsample.conv", "P_bg328", "P_bg329" 1552 | ), 1553 | **shortcutblock( 1554 | "first_stage_model.encoder.down.1.block.0", "P_bg330", "P_bg331" 1555 | ), 1556 | **easyblock2("first_stage_model.encoder.down.1.block.1", "P_bg331"), 1557 | **conv( 1558 | "first_stage_model.encoder.down.1.downsample.conv", "P_bg331", "P_bg332" 1559 | ), 1560 | **shortcutblock( 1561 | "first_stage_model.encoder.down.2.block.0", "P_bg332", "P_bg333" 1562 | ), 1563 | **easyblock2("first_stage_model.encoder.down.2.block.1", "P_bg333"), 1564 | **conv( 1565 | "first_stage_model.encoder.down.2.downsample.conv", "P_bg333", "P_bg334" 1566 | ), 1567 | **easyblock2("first_stage_model.encoder.down.3.block.0", "P_bg334"), 1568 | **easyblock2("first_stage_model.encoder.down.3.block.1", "P_bg334"), 1569 | # encoder mid-block 1570 | **easyblock2("first_stage_model.encoder.mid.block_1", "P_bg334"), 1571 | **norm("first_stage_model.encoder.mid.attn_1.norm", "P_bg334"), 1572 | **conv("first_stage_model.encoder.mid.attn_1.q", "P_bg334", "P_bg335"), 1573 | **conv("first_stage_model.encoder.mid.attn_1.k", "P_bg334", "P_bg335"), 1574 | **conv("first_stage_model.encoder.mid.attn_1.v", "P_bg334", "P_bg335"), 1575 | **conv( 1576 | "first_stage_model.encoder.mid.attn_1.proj_out", "P_bg335", "P_bg336" 1577 | ), 1578 | 
**easyblock2("first_stage_model.encoder.mid.block_2", "P_bg336"), 1579 | **norm("first_stage_model.encoder.norm_out", "P_bg337"), 1580 | **conv("first_stage_model.encoder.conv_out", "P_bg338", "P_bg339"), 1581 | **conv("first_stage_model.decoder.conv_in", "P_bg340", "P_bg341"), 1582 | # decoder mid-block 1583 | **easyblock2("first_stage_model.decoder.mid.block_1", "P_bg342"), 1584 | **norm("first_stage_model.decoder.mid.attn_1.norm", "P_bg342"), 1585 | **conv("first_stage_model.decoder.mid.attn_1.q", "P_bg342", "P_bg343"), 1586 | **conv("first_stage_model.decoder.mid.attn_1.k", "P_bg342", "P_bg343"), 1587 | **conv("first_stage_model.decoder.mid.attn_1.v", "P_bg342", "P_bg343"), 1588 | **conv( 1589 | "first_stage_model.decoder.mid.attn_1.proj_out", "P_bg343", "P_bg344" 1590 | ), 1591 | **easyblock2("first_stage_model.decoder.mid.block_2", "P_bg345"), 1592 | # decoder up 1593 | **shortcutblock( 1594 | "first_stage_model.decoder.up.0.block.0", "P_bg346", "P_bg347" 1595 | ), 1596 | **easyblock2("first_stage_model.decoder.up.0.block.1", "P_bg348"), 1597 | **easyblock2("first_stage_model.decoder.up.0.block.2", "P_bg349"), 1598 | **shortcutblock( 1599 | "first_stage_model.decoder.up.1.block.0", "P_bg350", "P_bg351" 1600 | ), 1601 | **easyblock2("first_stage_model.decoder.up.1.block.1", "P_bg352"), 1602 | **easyblock2("first_stage_model.decoder.up.1.block.2", "P_bg353"), 1603 | **conv( 1604 | "first_stage_model.decoder.up.1.upsample.conv", "P_bg353", "P_bg354" 1605 | ), 1606 | **easyblock2("first_stage_model.decoder.up.2.block.0", "P_bg355"), 1607 | **easyblock2("first_stage_model.decoder.up.2.block.1", "P_bg355"), 1608 | **easyblock2("first_stage_model.decoder.up.2.block.2", "P_bg355"), 1609 | **conv( 1610 | "first_stage_model.decoder.up.2.upsample.conv", "P_bg355", "P_bg356" 1611 | ), 1612 | **easyblock2("first_stage_model.decoder.up.3.block.0", "P_bg356"), 1613 | **easyblock2("first_stage_model.decoder.up.3.block.1", "P_bg356"), 1614 | 
**easyblock2("first_stage_model.decoder.up.3.block.2", "P_bg356"), 1615 | **conv( 1616 | "first_stage_model.decoder.up.3.upsample.conv", "P_bg356", "P_bg357" 1617 | ), 1618 | **norm("first_stage_model.decoder.norm_out", "P_bg358"), 1619 | **conv("first_stage_model.decoder.conv_out", "P_bg359", "P_bg360"), 1620 | **conv("first_stage_model.quant_conv", "P_bg361", "P_bg362"), 1621 | **conv("first_stage_model.post_quant_conv", "P_bg363", "P_bg364"), 1622 | **skip( 1623 | "cond_stage_model.transformer.text_model.embeddings.position_ids", 1624 | None, 1625 | None, 1626 | ), 1627 | **dense( 1628 | "cond_stage_model.transformer.text_model.embeddings.token_embedding", 1629 | "P_bg365", 1630 | "P_bg366", 1631 | bias=False, 1632 | ), 1633 | **dense( 1634 | "cond_stage_model.transformer.text_model.embeddings.token_embedding", 1635 | None, 1636 | None, 1637 | ), 1638 | **dense( 1639 | "cond_stage_model.transformer.text_model.embeddings.position_embedding", 1640 | "P_bg367", 1641 | "P_bg368", 1642 | bias=False, 1643 | ), 1644 | # cond stage text encoder 1645 | **dense( 1646 | "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj", 1647 | "P_bg369", 1648 | "P_bg370", 1649 | bias=True, 1650 | ), 1651 | **dense( 1652 | "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj", 1653 | "P_bg369", 1654 | "P_bg370", 1655 | bias=True, 1656 | ), 1657 | **dense( 1658 | "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj", 1659 | "P_bg369", 1660 | "P_bg370", 1661 | bias=True, 1662 | ), 1663 | **dense( 1664 | "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj", 1665 | "P_bg369", 1666 | "P_bg370", 1667 | bias=True, 1668 | ), 1669 | **norm( 1670 | "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1", 1671 | "P_bg370", 1672 | ), 1673 | **dense( 1674 | "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1", 1675 | "P_bg370", 1676 | "P_bg371", 1677 | bias=True, 1678 | ), 1679 | 
**dense( 1680 | "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2", 1681 | "P_bg371", 1682 | "P_bg372", 1683 | bias=True, 1684 | ), 1685 | **norm( 1686 | "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2", 1687 | "P_bg372", 1688 | ), 1689 | **dense( 1690 | "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj", 1691 | "P_bg372", 1692 | "P_bg373", 1693 | bias=True, 1694 | ), 1695 | **dense( 1696 | "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj", 1697 | "P_bg372", 1698 | "P_bg373", 1699 | bias=True, 1700 | ), 1701 | **dense( 1702 | "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj", 1703 | "P_bg372", 1704 | "P_bg373", 1705 | bias=True, 1706 | ), 1707 | **dense( 1708 | "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj", 1709 | "P_bg372", 1710 | "P_bg373", 1711 | bias=True, 1712 | ), 1713 | **norm( 1714 | "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1", 1715 | "P_bg373", 1716 | ), 1717 | **dense( 1718 | "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1", 1719 | "P_bg373", 1720 | "P_bg374", 1721 | bias=True, 1722 | ), 1723 | **dense( 1724 | "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2", 1725 | "P_bg374", 1726 | "P_bg375", 1727 | bias=True, 1728 | ), 1729 | **norm( 1730 | "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2", 1731 | "P_bg375", 1732 | ), 1733 | **dense( 1734 | "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj", 1735 | "P_bg375", 1736 | "P_bg376", 1737 | bias=True, 1738 | ), 1739 | **dense( 1740 | "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj", 1741 | "P_bg375", 1742 | "P_bg376", 1743 | bias=True, 1744 | ), 1745 | **dense( 1746 | "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj", 1747 | "P_bg375", 1748 | "P_bg376", 1749 | bias=True, 1750 | ), 1751 | **dense( 1752 | 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj", 1753 | "P_bg375", 1754 | "P_bg376", 1755 | bias=True, 1756 | ), 1757 | **norm( 1758 | "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1", 1759 | "P_bg376", 1760 | ), 1761 | **dense( 1762 | "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1", 1763 | "P_bg376", 1764 | "P_bg377", 1765 | bias=True, 1766 | ), 1767 | **dense( 1768 | "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2", 1769 | "P_bg377", 1770 | "P_bg378", 1771 | bias=True, 1772 | ), 1773 | **norm( 1774 | "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2", 1775 | "P_bg378", 1776 | ), 1777 | **dense( 1778 | "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj", 1779 | "P_bg378", 1780 | "P_bg379", 1781 | bias=True, 1782 | ), 1783 | **dense( 1784 | "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj", 1785 | "P_bg378", 1786 | "P_bg379", 1787 | bias=True, 1788 | ), 1789 | **dense( 1790 | "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj", 1791 | "P_bg378", 1792 | "P_bg379", 1793 | bias=True, 1794 | ), 1795 | **dense( 1796 | "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj", 1797 | "P_bg378", 1798 | "P_bg379", 1799 | bias=True, 1800 | ), 1801 | **norm( 1802 | "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1", 1803 | "P_bg379", 1804 | ), 1805 | **dense( 1806 | "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1", 1807 | "P_bg379", 1808 | "P_bg380", 1809 | bias=True, 1810 | ), 1811 | **dense( 1812 | "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2", 1813 | "P_bg380", 1814 | "P_b381", 1815 | bias=True, 1816 | ), 1817 | **norm( 1818 | "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2", 1819 | "P_bg381", 1820 | ), 1821 | **dense( 1822 | 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj", 1823 | "P_bg381", 1824 | "P_bg382", 1825 | bias=True, 1826 | ), 1827 | **dense( 1828 | "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj", 1829 | "P_bg381", 1830 | "P_bg382", 1831 | bias=True, 1832 | ), 1833 | **dense( 1834 | "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj", 1835 | "P_bg381", 1836 | "P_bg382", 1837 | bias=True, 1838 | ), 1839 | **dense( 1840 | "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj", 1841 | "P_bg381", 1842 | "P_bg382", 1843 | bias=True, 1844 | ), 1845 | **norm( 1846 | "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1", 1847 | "P_bg382", 1848 | ), 1849 | **dense( 1850 | "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1", 1851 | "P_bg382", 1852 | "P_bg383", 1853 | bias=True, 1854 | ), 1855 | **dense( 1856 | "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2", 1857 | "P_bg383", 1858 | "P_bg384", 1859 | bias=True, 1860 | ), 1861 | **norm( 1862 | "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2", 1863 | "P_bg384", 1864 | ), 1865 | **dense( 1866 | "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj", 1867 | "P_bg384", 1868 | "P_bg385", 1869 | bias=True, 1870 | ), 1871 | **dense( 1872 | "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj", 1873 | "P_bg384", 1874 | "P_bg385", 1875 | bias=True, 1876 | ), 1877 | **dense( 1878 | "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj", 1879 | "P_bg384", 1880 | "P_bg385", 1881 | bias=True, 1882 | ), 1883 | **dense( 1884 | "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj", 1885 | "P_bg384", 1886 | "P_bg385", 1887 | bias=True, 1888 | ), 1889 | **norm( 1890 | "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1", 1891 | "P_bg385", 1892 | ), 1893 | **dense( 1894 | 
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1", 1895 | "P_bg385", 1896 | "P_bg386", 1897 | bias=True, 1898 | ), 1899 | **dense( 1900 | "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2", 1901 | "P_bg386", 1902 | "P_bg387", 1903 | bias=True, 1904 | ), 1905 | **norm( 1906 | "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2", 1907 | "P_bg387", 1908 | ), 1909 | **dense( 1910 | "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj", 1911 | "P_bg387", 1912 | "P_bg388", 1913 | bias=True, 1914 | ), 1915 | **dense( 1916 | "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj", 1917 | "P_bg387", 1918 | "P_bg388", 1919 | bias=True, 1920 | ), 1921 | **dense( 1922 | "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj", 1923 | "P_bg387", 1924 | "P_bg388", 1925 | bias=True, 1926 | ), 1927 | **dense( 1928 | "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj", 1929 | "P_bg387", 1930 | "P_bg388", 1931 | bias=True, 1932 | ), 1933 | **norm( 1934 | "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1", 1935 | "P_bg389", 1936 | ), 1937 | **dense( 1938 | "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1", 1939 | "P_bg389", 1940 | "P_bg390", 1941 | bias=True, 1942 | ), 1943 | **dense( 1944 | "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2", 1945 | "P_bg390", 1946 | "P_bg391", 1947 | bias=True, 1948 | ), 1949 | **norm( 1950 | "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2", 1951 | "P_bg391", 1952 | ), 1953 | **dense( 1954 | "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj", 1955 | "P_bg391", 1956 | "P_bg392", 1957 | bias=True, 1958 | ), 1959 | **dense( 1960 | "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj", 1961 | "P_bg391", 1962 | "P_bg392", 1963 | bias=True, 1964 | ), 1965 | **dense( 1966 | 
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj", 1967 | "P_bg391", 1968 | "P_bg392", 1969 | bias=True, 1970 | ), 1971 | **dense( 1972 | "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj", 1973 | "P_bg391", 1974 | "P_bg392", 1975 | bias=True, 1976 | ), 1977 | **norm( 1978 | "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1", 1979 | "P_bg392", 1980 | ), 1981 | **dense( 1982 | "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1", 1983 | "P_bg392", 1984 | "P_bg393", 1985 | bias=True, 1986 | ), 1987 | **dense( 1988 | "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2", 1989 | "P_bg393", 1990 | "P_bg394", 1991 | bias=True, 1992 | ), 1993 | **norm( 1994 | "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2", 1995 | "P_bg394", 1996 | ), 1997 | **dense( 1998 | "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj", 1999 | "P_bg394", 2000 | "P_bg395", 2001 | bias=True, 2002 | ), 2003 | **dense( 2004 | "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj", 2005 | "P_bg394", 2006 | "P_bg395", 2007 | bias=True, 2008 | ), 2009 | **dense( 2010 | "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj", 2011 | "P_bg394", 2012 | "P_bg395", 2013 | bias=True, 2014 | ), 2015 | **dense( 2016 | "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj", 2017 | "P_bg394", 2018 | "P_bg395", 2019 | bias=True, 2020 | ), 2021 | **norm( 2022 | "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1", 2023 | "P_bg395", 2024 | ), 2025 | **dense( 2026 | "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1", 2027 | "P_bg395", 2028 | "P_bg396", 2029 | bias=True, 2030 | ), 2031 | **dense( 2032 | "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2", 2033 | "P_bg396", 2034 | "P_bg397", 2035 | bias=True, 2036 | ), 2037 | **norm( 2038 | 
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2", 2039 | "P_bg397", 2040 | ), 2041 | **dense( 2042 | "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj", 2043 | "P_bg397", 2044 | "P_bg398", 2045 | bias=True, 2046 | ), 2047 | **dense( 2048 | "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj", 2049 | "P_bg397", 2050 | "P_bg398", 2051 | bias=True, 2052 | ), 2053 | **dense( 2054 | "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj", 2055 | "P_bg397", 2056 | "P_bg398", 2057 | bias=True, 2058 | ), 2059 | **dense( 2060 | "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj", 2061 | "P_bg397", 2062 | "P_bg398", 2063 | bias=True, 2064 | ), 2065 | **norm( 2066 | "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1", 2067 | "P_bg398", 2068 | ), 2069 | **dense( 2070 | "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1", 2071 | "P_bg398", 2072 | "P_bg399", 2073 | bias=True, 2074 | ), 2075 | **dense( 2076 | "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2", 2077 | "P_bg400", 2078 | "P_bg401", 2079 | bias=True, 2080 | ), 2081 | **norm( 2082 | "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2", 2083 | "P_bg401", 2084 | ), 2085 | **dense( 2086 | "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj", 2087 | "P_bg401", 2088 | "P_bg402", 2089 | bias=True, 2090 | ), 2091 | **dense( 2092 | "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj", 2093 | "P_bg401", 2094 | "P_bg402", 2095 | bias=True, 2096 | ), 2097 | **dense( 2098 | "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj", 2099 | "P_bg401", 2100 | "P_bg402", 2101 | bias=True, 2102 | ), 2103 | **dense( 2104 | "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj", 2105 | "P_bg401", 2106 | "P_bg402", 2107 | bias=True, 2108 | ), 2109 | **norm( 2110 | 
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1", 2111 | "P_bg402", 2112 | ), 2113 | **dense( 2114 | "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1", 2115 | "P_bg402", 2116 | "P_bg403", 2117 | bias=True, 2118 | ), 2119 | **dense( 2120 | "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2", 2121 | "P_bg403", 2122 | "P_bg404", 2123 | bias=True, 2124 | ), 2125 | **norm( 2126 | "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2", 2127 | "P_bg404", 2128 | ), 2129 | **dense( 2130 | "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj", 2131 | "P_bg404", 2132 | "P_bg405", 2133 | bias=True, 2134 | ), 2135 | **dense( 2136 | "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj", 2137 | "P_bg404", 2138 | "P_bg405", 2139 | bias=True, 2140 | ), 2141 | **dense( 2142 | "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj", 2143 | "P_bg404", 2144 | "P_bg405", 2145 | bias=True, 2146 | ), 2147 | **dense( 2148 | "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj", 2149 | "P_bg404", 2150 | "P_bg405", 2151 | bias=True, 2152 | ), 2153 | **norm( 2154 | "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1", 2155 | "P_bg405", 2156 | ), 2157 | **dense( 2158 | "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1", 2159 | "P_bg405", 2160 | "P_bg406", 2161 | bias=True, 2162 | ), 2163 | **dense( 2164 | "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2", 2165 | "P_bg406", 2166 | "P_bg407", 2167 | bias=True, 2168 | ), 2169 | **norm( 2170 | "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2", 2171 | "P_bg407", 2172 | ), 2173 | **norm( 2174 | "cond_stage_model.transformer.text_model.final_layer_norm", "P_bg407" 2175 | ), 2176 | } 2177 | ) 2178 | 2179 | 2180 | def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None): 2181 | """Get 
parameter `k` from `params`, with the permutations applied.""" 2182 | w = params[k] 2183 | for axis, p in enumerate(ps.axes_to_perm[k]): 2184 | # Skip the axis we're trying to permute. 2185 | if axis == except_axis: 2186 | continue 2187 | 2188 | # None indicates that there is no permutation relevant to that axis. 2189 | if p: 2190 | w = torch.index_select(w, axis, perm[p].int()) 2191 | 2192 | return w 2193 | 2194 | 2195 | def apply_permutation(ps: PermutationSpec, perm, params): 2196 | """Apply a `perm` to `params`.""" 2197 | return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()} 2198 | 2199 | 2200 | def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): 2201 | for k in model_a: 2202 | try: 2203 | perm_params = get_permuted_param(ps, perm, k, model_a) 2204 | model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params 2205 | except RuntimeError: # dealing with pix2pix and inpainting models 2206 | continue 2207 | return model_a 2208 | 2209 | 2210 | def inner_matching( 2211 | n, 2212 | ps, 2213 | p, 2214 | params_a, 2215 | params_b, 2216 | usefp16, 2217 | progress, 2218 | number, 2219 | linear_sum, 2220 | perm, 2221 | device, 2222 | ): 2223 | A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n)) 2224 | A = A.to(device) 2225 | 2226 | for wk, axis in ps.perm_to_axes[p]: 2227 | w_a = params_a[wk] 2228 | w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) 2229 | w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device) 2230 | w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device) 2231 | 2232 | if usefp16: 2233 | w_a = w_a.half().to(device) 2234 | w_b = w_b.half().to(device) 2235 | 2236 | try: 2237 | A += torch.matmul(w_a, w_b) 2238 | except RuntimeError: 2239 | A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b)) 2240 | 2241 | A = A.cpu() 2242 | ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True) 2243 | A = A.to(device) 2244 | 2245 | assert 
(torch.tensor(ri) == torch.arange(len(ri))).all() 2246 | 2247 | eye_tensor = torch.eye(n).to(device) 2248 | 2249 | oldL = torch.vdot( 2250 | torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()]) 2251 | ) 2252 | newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :])) 2253 | 2254 | if usefp16: 2255 | oldL = oldL.half() 2256 | newL = newL.half() 2257 | 2258 | if newL - oldL != 0: 2259 | linear_sum += abs((newL - oldL).item()) 2260 | number += 1 2261 | logging.info(f" permutation {p}: {newL - oldL}") 2262 | 2263 | progress = progress or newL > oldL + 1e-12 2264 | 2265 | perm[p] = torch.Tensor(ci).to(device) 2266 | 2267 | return linear_sum, number, perm, progress 2268 | 2269 | 2270 | def weight_matching( 2271 | ps: PermutationSpec, 2272 | params_a, 2273 | params_b, 2274 | max_iter=1, 2275 | init_perm=None, 2276 | usefp16=False, 2277 | device="cpu", 2278 | ): 2279 | perm_sizes = { 2280 | p: params_a[axes[0][0]].shape[axes[0][1]] 2281 | for p, axes in ps.perm_to_axes.items() 2282 | if axes[0][0] in params_a.keys() 2283 | } 2284 | perm = {} 2285 | perm = ( 2286 | {p: torch.arange(n).to(device) for p, n in perm_sizes.items()} 2287 | if init_perm is None 2288 | else init_perm 2289 | ) 2290 | 2291 | linear_sum = 0 2292 | number = 0 2293 | 2294 | special_layers = ["P_bg324", "P_bg358", "P_bg337"] 2295 | for _ in range(max_iter): 2296 | progress = False 2297 | shuffle(special_layers) 2298 | for p in special_layers: 2299 | n = perm_sizes[p] 2300 | 2301 | linear_sum, number, perm, progress = inner_matching( 2302 | n, 2303 | ps, 2304 | p, 2305 | params_a, 2306 | params_b, 2307 | usefp16, 2308 | progress, 2309 | number, 2310 | linear_sum, 2311 | perm, 2312 | device, 2313 | ) 2314 | if not progress: 2315 | break 2316 | 2317 | average = linear_sum / number if number > 0 else 0 2318 | return (perm, average) 2319 | -------------------------------------------------------------------------------- /sd_meh/utils.py: 
-------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | 4 | from sd_meh import merge_methods 5 | from sd_meh.merge import NUM_TOTAL_BLOCKS 6 | from sd_meh.presets import BLOCK_WEIGHTS_PRESETS 7 | 8 | MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction)) 9 | BETA_METHODS = [ 10 | name 11 | for name, fn in MERGE_METHODS.items() 12 | if "beta" in inspect.getfullargspec(fn)[0] 13 | ] 14 | 15 | 16 | def compute_weights(weights, base): 17 | if not weights: 18 | return [base] * NUM_TOTAL_BLOCKS 19 | 20 | if "," not in weights: 21 | return weights 22 | 23 | w_alpha = list(map(float, weights.split(","))) 24 | if len(w_alpha) == NUM_TOTAL_BLOCKS: 25 | return w_alpha 26 | 27 | 28 | def assemble_weights_and_bases(preset, weights, base, greek_letter): 29 | logging.info(f"Assembling {greek_letter} w&b") 30 | if preset: 31 | logging.info(f"Using {preset} preset") 32 | base, *weights = BLOCK_WEIGHTS_PRESETS[preset] 33 | bases = {greek_letter: base} 34 | weights = {greek_letter: compute_weights(weights, base)} 35 | 36 | logging.info(f"base_{greek_letter}: {bases[greek_letter]}") 37 | logging.info(f"{greek_letter} weights: {weights[greek_letter]}") 38 | 39 | return weights, bases 40 | 41 | 42 | def interpolate_presets( 43 | weights, bases, weights_b, bases_b, greek_letter, presets_lambda 44 | ): 45 | logging.info(f"Interpolating {greek_letter} w&b") 46 | for i, e in enumerate(weights[greek_letter]): 47 | weights[greek_letter][i] = ( 48 | 1 - presets_lambda 49 | ) * e + presets_lambda * weights_b[greek_letter][i] 50 | 51 | bases[greek_letter] = (1 - presets_lambda) * bases[ 52 | greek_letter 53 | ] + presets_lambda * bases_b[greek_letter] 54 | 55 | logging.info(f"Interpolated base_{greek_letter}: {bases[greek_letter]}") 56 | logging.info(f"Interpolated {greek_letter} weights: {weights[greek_letter]}") 57 | 58 | return weights, bases 59 | 60 | 61 | def weights_and_bases( 62 | merge_mode, 63 | 
weights_alpha, 64 | base_alpha, 65 | block_weights_preset_alpha, 66 | weights_beta, 67 | base_beta, 68 | block_weights_preset_beta, 69 | block_weights_preset_alpha_b, 70 | block_weights_preset_beta_b, 71 | presets_alpha_lambda, 72 | presets_beta_lambda, 73 | ): 74 | weights, bases = assemble_weights_and_bases( 75 | block_weights_preset_alpha, 76 | weights_alpha, 77 | base_alpha, 78 | "alpha", 79 | ) 80 | 81 | if block_weights_preset_alpha_b: 82 | logging.info("Computing w&b for alpha interpolation") 83 | weights_b, bases_b = assemble_weights_and_bases( 84 | block_weights_preset_alpha_b, 85 | None, 86 | None, 87 | "alpha", 88 | ) 89 | weights, bases = interpolate_presets( 90 | weights, 91 | bases, 92 | weights_b, 93 | bases_b, 94 | "alpha", 95 | presets_alpha_lambda, 96 | ) 97 | 98 | if merge_mode in BETA_METHODS: 99 | weights_beta, bases_beta = assemble_weights_and_bases( 100 | block_weights_preset_beta, 101 | weights_beta, 102 | base_beta, 103 | "beta", 104 | ) 105 | 106 | if block_weights_preset_beta_b: 107 | logging.info("Computing w&b for beta interpolation") 108 | weights_b, bases_b = assemble_weights_and_bases( 109 | block_weights_preset_beta_b, 110 | None, 111 | None, 112 | "beta", 113 | ) 114 | weights, bases = interpolate_presets( 115 | weights, 116 | bases, 117 | weights_b, 118 | bases_b, 119 | "beta", 120 | presets_beta_lambda, 121 | ) 122 | 123 | weights |= weights_beta 124 | bases |= bases_beta 125 | 126 | return weights, bases 127 | --------------------------------------------------------------------------------