├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── LICENSE-pytorch.txt ├── LICENSE-safetensors.txt ├── README.md ├── examples ├── add_perpendicular.py ├── basic_merge.py ├── binomial_dropout_merge.py ├── custom_merge_method.py ├── in_memory_merge.py ├── lora_merge.py ├── n_average.py ├── recipe_deserialize.py ├── recipe_serialize.py ├── rotate.py ├── split_unet_text_encoder.py └── ties_add_difference.py ├── media ├── did-you-see-something.PNG └── memory-gone.PNG ├── pyproject.toml ├── requirements.txt ├── scripts ├── convert_ckpt_to_safetensors.py ├── dump_keys.py └── yaml_from_safetensors.py ├── sd_mecha ├── __init__.py ├── conversion.py ├── extensions │ ├── __init__.py │ ├── builtin │ │ ├── lycoris.py │ │ ├── merge_methods │ │ │ ├── __init__.py │ │ │ ├── clamp.py │ │ │ ├── cosine.py │ │ │ ├── crossover.py │ │ │ ├── ema.py │ │ │ ├── linear.py │ │ │ ├── logistics.py │ │ │ ├── slicing.py │ │ │ ├── svd.py │ │ │ └── ties_sum.py │ │ └── model_configs │ │ │ ├── __init__.py │ │ │ ├── convert_flux.py │ │ │ ├── convert_huggingface_sd_vae_to_original.py │ │ │ ├── convert_sd1_blocks.py │ │ │ ├── convert_sd1_kohya_to_original.py │ │ │ ├── convert_sdxl_blocks.py │ │ │ ├── convert_sdxl_diffusers_unet_to_original.py │ │ │ ├── convert_sdxl_kohya_to_original.py │ │ │ ├── flux-flux.yaml │ │ │ ├── flux-flux_diffuser_only.yaml │ │ │ ├── sd1-kohya.yaml │ │ │ ├── sd1-ldm.yaml │ │ │ ├── sd1-supermerger_blocks.yaml │ │ │ ├── sd3-comfyui.yaml │ │ │ ├── sdxl-diffusers_unet_only.yaml │ │ │ ├── sdxl-kohya.yaml │ │ │ ├── sdxl-sgm.yaml │ │ │ └── sdxl-supermerger_blocks.yaml │ ├── merge_methods.py │ ├── merge_spaces.py │ ├── model_configs.py │ └── model_formats.py ├── helpers.py ├── merge_method_wrappers.py ├── merging.py ├── recipe_nodes.py ├── serialization.py ├── streaming.py └── typing_.py └── tests ├── extensions ├── __init__.py └── test_merge_methods.py └── merge_methods ├── __init__.py ├── test_della.py ├── test_geometric_median.py ├── test_modelstock.py └── test_ties.py 
/.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: ljleb 4 | patreon: ljleb 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | __pycache__ 3 | /sd_mecha.egg-info 4 | /build 5 | /dist 6 | /config 7 | /venv 8 | /model_configs/venvs 9 | 10 | /examples/*_test.py 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ljleb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE-pytorch.txt: -------------------------------------------------------------------------------- 1 | All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright 10 | notice, this list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | 13 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 14 | and IDIAP Research Institute nor the names of its contributors may be 15 | used to endorse or promote products derived from this software without 16 | specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /LICENSE-safetensors.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sd-mecha 2 | 3 | [![PyPI version](https://badge.fury.io/py/sd-mecha.svg)](https://badge.fury.io/py/sd-mecha) 4 | [![Discord Server](https://dcbadge.vercel.app/api/server/2EPaw6fxxm?style=flat)](https://discord.gg/invite/2EPaw6fxxm) 5 | 6 | ```python 7 | import sd_mecha 8 | 9 | # create the merge plan 10 | a = sd_mecha.model("path/to/model_a.safetensors") 11 | b = sd_mecha.model("path/to/model_b.safetensors") 12 | recipe = sd_mecha.weighted_sum(a, b, alpha=0.5) 13 | 14 | # merge! 15 | sd_mecha.merge(recipe, output="path/to/model_out.safetensors") 16 | ``` 17 | 18 | sd-mecha is a general memory-efficient model merging library. It can merge *any* model: 19 | - Diffusion models 20 | - LLMs 21 | - VLMs 22 | - Aesthetic scorers 23 | - etc. 
24 | 25 | ## Features 26 | 27 | - Memory efficient model merging: merge a very large number of models in a single execution 28 | - Textual and interpretable format for storage and execution (.mecha) 29 | - Extensible library interface: 30 | - add custom models 31 | - add custom merge methods 32 | - Builtin support for popular diffusion models: 33 | - Stable Diffusion 1.5 34 | - Stable Diffusion XL 35 | - Stable Diffusion 3 36 | - FLUX Schnell/Dev 37 | - Merge LyCORIS networks together and to checkpoints 38 | - Block-wise hyperparameters for precise control of blocks (aka MBW) 39 | 40 | ## Install 41 | 42 | ```commandline 43 | pip install sd-mecha 44 | ``` 45 | 46 | Make sure to install the appropriate release of [`torch`](https://pytorch.org/get-started/locally/) to get the best performance. 47 | 48 | ## Usage 49 | 50 | ### Merge models 51 | 52 | To merge models, mecha uses recipes. 53 | A recipe is a list of instructions that describes the exact steps needed to derive a state dict from inputs. 54 | 55 | Here's an example script that merges three models: 56 | 57 | ```python 58 | import sd_mecha 59 | 60 | # create the merge plan 61 | model_a = sd_mecha.model("path/to/model_a.safetensors") 62 | model_b = sd_mecha.model("path/to/model_b.safetensors") 63 | recipe = sd_mecha.weighted_sum(model_a, model_b, alpha=0.5) 64 | 65 | # merge! 66 | sd_mecha.merge(recipe, output="path/to/model_out.safetensors") 67 | ``` 68 | 69 | See the [examples](/examples) directory for more examples. 70 | 71 | ### Get Model-Specific Information 72 | 73 | The library uses a "model config" to designate any specific set of keys of a certain shape. 
74 | 75 | It is possible to list all available model configs through the `sd_mecha.extensions.model_configs` module: 76 | 77 | ```python 78 | from sd_mecha.extensions import model_configs 79 | 80 | all_configs = model_configs.get_all() 81 | 82 | print([config.identifier for config in all_configs]) 83 | # ["sd1-ldm-base", "sdxl-sgm-base", "sd3-sgm-base", ...] 84 | ``` 85 | 86 | A *component* of a model config is a subset of keys of the config that belong to the same logical group. 87 | For example, all keys starting with "first_stage_model." in Stable Diffusion models belong to the component "vae". 88 | 89 | It is possible to query the different components of a model config: 90 | 91 | ```python 92 | from sd_mecha.extensions import model_configs 93 | 94 | config = model_configs.resolve("sd1-ldm") 95 | for component_id, component in config.components().items(): 96 | # component.keys contains the state dict keys that the component owns 97 | print(f"{component_id}") 98 | 99 | # this prints: 100 | # clip_l 101 | # vae 102 | # diffuser 103 | ``` 104 | 105 | ## Motivation 106 | 107 | Keeping track of full merge recipes has always been a problem for me. 108 | I needed something that allows to store merge recipes in a readable format while also being executable. 109 | I also needed something that allows to fully merge an entire tree of models without having to save intermediate models to disk. 110 | 111 | Typically, mergers load all models in memory before initiating the merge process. 112 | This can be very inefficient when the merge focuses on each key individually: 113 | 114 | ![image of typical merge graph](/media/memory-gone.PNG) 115 | 116 | sd-mecha doesn't have this problem as it saves keys as soon as it can: 117 | 118 | ![image of sd-mecha merge graph](/media/did-you-see-something.PNG) 119 | 120 | This allows to merge a very large number of models simultaneously on low-end hardware. 
121 | -------------------------------------------------------------------------------- /examples/add_perpendicular.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | recipe = sd_mecha.add_perpendicular( 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | sd_mecha.model("pure/v1-5-pruned.safetensors"), 9 | ) 10 | 11 | sd_mecha.merge(recipe, output="result.safetensors") 12 | -------------------------------------------------------------------------------- /examples/basic_merge.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | # plan a simple weighted sum 6 | a = sd_mecha.model("ghostmix_v20Bakedvae.safetensors") 7 | b = sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors") 8 | recipe = sd_mecha.weighted_sum(a, b) 9 | 10 | # perform the entire merge plan and save to output path 11 | sd_mecha.merge(recipe, output="output.safetensors") 12 | -------------------------------------------------------------------------------- /examples/binomial_dropout_merge.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | recipe = sd_mecha.dropout( 5 | sd_mecha.model("pure/v1-5-pruned.safetensors"), 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | probability=0.9, alpha=0.5, seed=0, 9 | ) 10 | sd = sd_mecha.merge(recipe) 11 | -------------------------------------------------------------------------------- /examples/custom_merge_method.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | import torch 3 | from torch import Tensor 4 | from sd_mecha.extensions.merge_methods import merge_method, Parameter, 
Return 5 | 6 | 7 | # sets loglevel to INFO. some operations will report extra detail through stdout/stderr 8 | sd_mecha.set_log_level() 9 | 10 | 11 | # define a custom merge method 12 | # `@merge_method` converts the decorated function to work with the merge method API 13 | @merge_method 14 | def custom_sum( 15 | # Each positional argument is a single tensor from one of the input models. 16 | # Merge methods are called once for each key that all input models have in common. 17 | a: Parameter(Tensor), 18 | b: Parameter(Tensor), 19 | alpha: Parameter(Tensor) = 0.5, # params with a default value are automatically in "param" merge space 20 | *, 21 | beta: Parameter(Tensor, merge_space="param"), 22 | # extra info or metadata is passed in **kwargs if it is present 23 | # this includes the name of the key currently being merged, a cache mechanism, and a few more things 24 | # add **kwargs to receive this extra information:` 25 | **kwargs, 26 | ) -> Return(Tensor): 27 | weighted_sum = (1-alpha)*a + alpha*b 28 | 29 | # just for the sake of the example, let's add noise to the sum 30 | return (1 - beta) * weighted_sum + beta * torch.randn_like(weighted_sum) 31 | 32 | 33 | # plan our custom weighted sum 34 | recipe = custom_sum( 35 | # this merge uses the "sd1-ldm" model config. 
36 | # the config is automatically inferred, there is no need to pass its identifier here 37 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 38 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 39 | # merge params, literal values (int, float, str, bool, None) are broadcasted to the entire state dict 40 | alpha=0.6, 41 | beta=0.1, 42 | ) 43 | 44 | # perform the entire merge plan and save to output path 45 | sd_mecha.merge(recipe, output="custom_merge.safetensors") 46 | -------------------------------------------------------------------------------- /examples/in_memory_merge.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | recipe = sd_mecha.weighted_sum( 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | ) 9 | 10 | sd = sd_mecha.merge(recipe) 11 | -------------------------------------------------------------------------------- /examples/lora_merge.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | base = sd_mecha.model("ghostmix_v20Bakedvae.safetensors") 6 | lora = sd_mecha.convert(sd_mecha.model("head-mounted display3-000007.safetensors"), base) 7 | recipe = sd_mecha.add_difference(base, lora, alpha=1.0) 8 | 9 | sd = sd_mecha.merge(recipe) 10 | -------------------------------------------------------------------------------- /examples/n_average.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | # 5 example models to average together 6 | models = [ 7 | "ghostmix_v20Bakedvae", 8 | "dreamshaper_332BakedVaeClipFix", 9 | "deliberate_v2", 10 | "darkSushi25D25D_v20", 11 | "CounterfeitV30_v30", 12 | ] 13 | 14 | recipe = sd_mecha.n_average(*models) 15 | sd = sd_mecha.merge(recipe) 16 | 
-------------------------------------------------------------------------------- /examples/recipe_deserialize.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import sd_mecha 3 | sd_mecha.set_log_level() 4 | 5 | 6 | recipe_path = pathlib.Path(__file__).parent / 'recipes' / "test_split_unet_text_encoder.mecha" 7 | with open(recipe_path) as f: 8 | recipe = sd_mecha.deserialize(f.readlines()) 9 | 10 | 11 | sd = sd_mecha.merge(recipe) 12 | -------------------------------------------------------------------------------- /examples/recipe_serialize.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | text_encoder_recipe = sd_mecha.add_perpendicular( 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | sd_mecha.model("pure/v1-5-pruned.safetensors") 9 | ) 10 | 11 | unet_recipe = sd_mecha.weighted_sum( 12 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 13 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 14 | ) 15 | 16 | recipe = sd_mecha.pick_component(unet_recipe, "diffuser") | text_encoder_recipe 17 | sd_mecha.serialize(recipe, output="recipes/test_split_unet_text_encoder.mecha") 18 | -------------------------------------------------------------------------------- /examples/rotate.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | recipe = sd_mecha.rotate( 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | ) 9 | 10 | sd = sd_mecha.merge(recipe) 11 | -------------------------------------------------------------------------------- /examples/split_unet_text_encoder.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 
2 | sd_mecha.set_log_level() 3 | 4 | 5 | text_encoder_recipe = sd_mecha.add_perpendicular( 6 | sd_mecha.model("js2prony_v10.safetensors"), 7 | sd_mecha.model("furry-xl-4.0.safetensors"), 8 | sd_mecha.model("pure/sdxl_base.safetensors"), 9 | ) 10 | 11 | unet_recipe = sd_mecha.weighted_sum( 12 | sd_mecha.model("js2prony_v10.safetensors"), 13 | sd_mecha.model("furry-xl-4.0.safetensors"), 14 | ) 15 | 16 | recipe = sd_mecha.pick_component(unet_recipe, "diffuser") | text_encoder_recipe 17 | sd = sd_mecha.merge(recipe) 18 | -------------------------------------------------------------------------------- /examples/ties_add_difference.py: -------------------------------------------------------------------------------- 1 | import sd_mecha 2 | sd_mecha.set_log_level() 3 | 4 | 5 | models = [ 6 | sd_mecha.model("ghostmix_v20Bakedvae.safetensors"), 7 | sd_mecha.model("dreamshaper_332BakedVaeClipFix.safetensors"), 8 | sd_mecha.model("realisticVisionV20_v20.safetensors"), 9 | sd_mecha.model("illustrationArtstyleMM_27.safetensors"), 10 | sd_mecha.model("lyriel_v16.safetensors"), 11 | sd_mecha.model("Midnight Maple.safetensors"), 12 | sd_mecha.model("mixproyuki77mi_v10.safetensors"), 13 | ] 14 | 15 | recipe = sd_mecha.add_difference_ties(sd_mecha.model("pure/v1-5-pruned.safetensors"), *models, alpha=0.5) 16 | sd = sd_mecha.merge(recipe, output_device="cpu") 17 | -------------------------------------------------------------------------------- /media/did-you-see-something.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljleb/sd-mecha/8150c2a16074cb8463751ddcddd8572d31830882/media/did-you-see-something.PNG -------------------------------------------------------------------------------- /media/memory-gone.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljleb/sd-mecha/8150c2a16074cb8463751ddcddd8572d31830882/media/memory-gone.PNG 
@click.command()
@click.option("-i", "--input", "input_path", type=pathlib.Path)
@click.option("-o", "--output", "output_path", type=pathlib.Path)
def main(input_path: pathlib.Path, output_path: pathlib.Path):
    """Convert a pickled .ckpt checkpoint into the .safetensors format.

    Loads the checkpoint at `input_path`, unwraps a "state_dict" envelope if
    present, and writes the remaining tensors to `output_path`.
    """
    if output_path.suffix != ".safetensors":
        print("can only convert to .safetensors")
        # exit() is injected by the `site` module and may be absent in some
        # launch modes; SystemExit is always available
        raise SystemExit(1)

    print("loading model...")
    # map_location="cpu" lets checkpoints saved from a CUDA device be converted
    # on machines without a GPU (torch.load would otherwise try to restore the
    # tensors onto their original device)
    ckpt = torch.load(input_path, map_location="cpu")
    ckpt = ckpt.get("state_dict", ckpt)
    # some checkpoints nest another "state_dict" key inside the state dict;
    # safetensors only accepts tensor values, so drop it
    if "state_dict" in ckpt:
        del ckpt["state_dict"]

    print("saving...")
    safetensors.torch.save_file(ckpt, output_path)
description="Generate a model config YAML from a safetensors model file.", 26 | ) 27 | parser.add_argument( 28 | "config_id", 29 | help="Model config identifier, usually `-`.", 30 | ) 31 | parser.add_argument( 32 | "path_to_model", 33 | help="Path to the safetensors model file to be used as reference.", 34 | ) 35 | args = parser.parse_args() 36 | create_model_config(args.config_id, pathlib.Path(args.path_to_model)) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /sd_mecha/__init__.py: -------------------------------------------------------------------------------- 1 | from .merging import merge, open_input_dicts, infer_model_configs 2 | from .serialization import serialize, deserialize, deserialize_path 3 | from .streaming import StateDictKeyError 4 | from .extensions.merge_methods import merge_method, value_to_node, RecipeNodeOrValue, Parameter, Return, StateDict 5 | from .conversion import convert 6 | from sd_mecha.extensions.builtin.merge_methods import ( 7 | weighted_sum, 8 | slerp, 9 | n_average, 10 | geometric_median, 11 | subtract, 12 | perpendicular_component, 13 | geometric_sum, 14 | train_difference_mask, 15 | add_opposite_mask, 16 | add_strict_opposite_mask, 17 | add_cosine_a, 18 | add_cosine_b, 19 | ties_sum, 20 | ties_sum_extended, 21 | ties_sum_with_dropout, 22 | crossover, 23 | clamp, 24 | model_stock, 25 | fallback, 26 | cast, 27 | get_dtype, 28 | get_device, 29 | pick_component, 30 | omit_component, 31 | exchange_ema, 32 | ) 33 | from .merge_method_wrappers import ( 34 | add_difference, 35 | add_perpendicular, 36 | add_difference_ties, 37 | add_difference_ties_extended, 38 | copy_region, 39 | tensor_sum, 40 | rotate, 41 | dropout, 42 | ties_with_dare, 43 | n_model_stock, 44 | ) 45 | from .helpers import model, literal, Defaults, set_log_level 46 | from . 
def convert(recipe: RecipeNodeOrValue, config: str | ModelConfig | RecipeNode, model_dirs: Iterable[pathlib.Path] = ()):
    """
    Convert a recipe from one model config to another.

    This searches for a chain of registered conversion functions that transform `recipe`’s underlying
    config into the target config, then composes them. For example, you might need to
    convert a LoRA adapter into the base model’s format.

    Args:
        recipe:
            A `RecipeNode` or dictionary representing the input model or partial recipe.
        config (str, ModelConfig or RecipeNode):
            The desired output config, or a recipe node that has the desired config.
        model_dirs (Iterable[Path], optional):
            Directories to resolve relative model paths.

    Returns:
        A new recipe node describing the entire conversion.

    Raises:
        ValueError:
            If no conversion path is found.
    """
    model_dirs = list(model_dirs)

    # build a directed graph over config identifiers: an edge src -> tgt exists for
    # every registered converter whose first input config is src and whose return config is tgt
    all_converters = merge_methods.get_all_converters()
    converter_paths: Dict[str, List[Tuple[str, Any]]] = {}
    for converter in all_converters:
        input_configs = converter.get_input_configs()
        return_config = converter.get_return_config(input_configs.args, input_configs.kwargs)
        src_config = input_configs.args[0].identifier
        tgt_config = return_config.identifier
        # register both endpoints as graph nodes even when they have no outgoing edges,
        # so dijkstra can index them without a KeyError
        converter_paths.setdefault(src_config, [])
        converter_paths.setdefault(tgt_config, [])
        converter_paths[src_config].append((tgt_config, converter))

    # a recipe node as the target: resolve it to its model config first
    # (opening the input dicts is what populates .model_config)
    if isinstance(config, RecipeNode):
        with open_input_dicts(config, model_dirs):
            config = config.model_config

    tgt_config = config if isinstance(config, str) else config.identifier

    if isinstance(recipe, Mapping):
        # raw state-dict input: the source config is ambiguous, so try every
        # inferred candidate config until one yields a conversion path
        possible_configs = infer_model_configs(recipe)
        for possible_config in (cfg for s in possible_configs for cfg in s):
            res = create_conversion_recipe(recipe, converter_paths, possible_config.identifier, tgt_config)
            if res is not None:
                return res
        raise ValueError(
            "could not infer the intended config to convert from. "
            "explicitly specifying the input config might resolve the issue"
        )

    recipe = value_to_node(recipe)
    # NOTE(review): the dump's indentation is ambiguous here — this reading closes the
    # input dicts right after the config identifier is read; confirm against upstream
    with open_input_dicts(recipe, model_dirs):
        src_config = recipe.model_config.identifier
    if src_config == "structural":
        raise ValueError(
            "recipe config is 'structural': "
            "structural recipes cannot be composed of any config conversions"
        )
    res = create_conversion_recipe(recipe, converter_paths, src_config, tgt_config)
    if res is None:
        raise ValueError(f"no config conversion exists from {src_config} to {tgt_config}")
    return res
" 61 | "explicitly specifying the input config might resolve the issue" 62 | ) 63 | 64 | recipe = value_to_node(recipe) 65 | with open_input_dicts(recipe, model_dirs): 66 | src_config = recipe.model_config.identifier 67 | if src_config == "structural": 68 | raise ValueError( 69 | "recipe config is 'structural': " 70 | "structural recipes cannot be composed of any config conversions" 71 | ) 72 | res = create_conversion_recipe(recipe, converter_paths, src_config, tgt_config) 73 | if res is None: 74 | raise ValueError(f"no config conversion exists from {src_config} to {tgt_config}") 75 | return res 76 | 77 | 78 | def create_conversion_recipe(recipe, paths, src_config, tgt_config): 79 | shortest_path = dijkstra(paths, src_config, tgt_config) 80 | if shortest_path is None: 81 | return None 82 | return functools.reduce(lambda v, mm: mm(v), shortest_path, recipe) 83 | 84 | 85 | def dijkstra(graph, start, goal): 86 | """ 87 | graph: Dict[str, List[Tuple[str, any_id]]] 88 | For each node (str), a list of (neighbor_node, edge_id). 89 | start: str 90 | goal: str 91 | 92 | Returns: List of edge IDs (in order) that forms the shortest path from start to goal, 93 | or None if no path exists. 
94 | """ 95 | 96 | distances = {node: float('inf') for node in graph} 97 | distances[start] = 0 98 | predecessors = {node: None for node in graph} # will store the node we came from 99 | edge_used = {node: None for node in graph} # will store which edge ID led here 100 | heap = [(0, start)] 101 | 102 | while heap: 103 | current_dist, current_node = heapq.heappop(heap) 104 | if current_dist > distances[current_node]: 105 | continue 106 | 107 | if current_node == goal: 108 | break 109 | 110 | for neighbor, edge_id in graph[current_node]: 111 | distance_via_current = current_dist + 1 112 | if distance_via_current < distances[neighbor]: 113 | distances[neighbor] = distance_via_current 114 | predecessors[neighbor] = current_node 115 | edge_used[neighbor] = edge_id 116 | heapq.heappush(heap, (distance_via_current, neighbor)) 117 | 118 | if distances[goal] == float('inf'): 119 | return None 120 | 121 | path_ids = [] 122 | node = goal 123 | while node != start: 124 | path_ids.append(edge_used[node]) 125 | node = predecessors[node] 126 | 127 | path_ids.reverse() 128 | return path_ids 129 | -------------------------------------------------------------------------------- /sd_mecha/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | from . import merge_spaces 2 | from . import model_formats 3 | from . import model_configs 4 | from . 
def compose_lora_up_down(state_dict: Mapping[str, torch.Tensor], key: str):
    """Materialize the weight delta of a LoRA module: (alpha / rank) * up ∘ down.

    The three keys are read in lexicographic order on purpose: any other order
    raises the number of cache misses in the input safetensors when streaming
    keys, which slows merging down significantly.
    """
    scale_alpha = state_dict[f"{key}.alpha"]
    down = state_dict[f"{key}.lora_down.weight"]
    up = state_dict[f"{key}.lora_up.weight"]
    rank = down.size(0)

    # apply the alpha/rank scaling to whichever factor is smaller
    scaling = scale_alpha / rank
    if up.numel() <= down.numel():
        up = up * scaling
    else:
        down = down * scaling

    if down.dim() == 2:
        # linear layer: a plain matmul of the two factors
        return up @ down
    if down.size()[2:4] == (1, 1):
        # 1x1 conv: equivalent to a matmul over the channel dims
        return (up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
    # larger conv kernels (e.g. 3x3): compose the two kernels with a convolution
    return torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
_to_lycoris_keys( 86 | keys: Mapping[StateDictKey, KeyMetadata], 87 | algorithms: Iterable[str], 88 | prefix: str, 89 | ) -> Dict[StateDictKey, KeyMetadata]: 90 | lycoris_keys = {} 91 | 92 | for algorithm in algorithms: 93 | for key, meta in keys.items(): 94 | if key.endswith("bias") or not getattr(meta.metadata().dtype, "is_floating_point", True): 95 | continue 96 | 97 | key = key.split('.') 98 | if key[-1] == "weight": 99 | key = key[:-1] 100 | key = "_".join(key) 101 | 102 | for suffix in lycoris_algorithms[algorithm]: 103 | lycoris_key = f"{prefix}_{key}.{suffix}" 104 | lycoris_keys[lycoris_key] = dataclasses.replace(meta, shape=[], optional=True) 105 | 106 | return lycoris_keys 107 | 108 | 109 | lycoris_algorithms = { 110 | "lora": ("lora_up.weight", "lora_down.weight", "alpha"), 111 | } 112 | 113 | 114 | _register_all_lycoris_configs() 115 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .clamp import clamp 2 | from .cosine import add_cosine_a, add_cosine_b 3 | from .crossover import crossover 4 | from .ema import exchange_ema 5 | from .linear import ( 6 | weighted_sum, 7 | n_average, 8 | slerp, 9 | add_difference, 10 | subtract, 11 | perpendicular_component, 12 | train_difference_mask, 13 | add_opposite_mask, 14 | add_strict_opposite_mask, 15 | geometric_sum, 16 | multiply_quotient, 17 | ) 18 | from .logistics import ( 19 | fallback, 20 | cast, 21 | get_dtype, 22 | get_device, 23 | pick_component, 24 | omit_component, 25 | cast_dtype_map, 26 | cast_dtype_map_reversed, 27 | ) 28 | from .slicing import tensor_sum, top_k_tensor_sum 29 | from .svd import rotate, truncate_rank 30 | from .ties_sum import ( 31 | ties_sum_with_dropout, 32 | ties_sum, 33 | ties_sum_extended, 34 | model_stock, 35 | geometric_median, 36 | dropout, 37 | ) 38 | 
-------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/clamp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch import Tensor 4 | from sd_mecha.extensions.merge_methods import merge_method, Parameter, Return 5 | 6 | 7 | @merge_method 8 | def clamp( 9 | a: Parameter(Tensor), 10 | *bounds: Parameter(Tensor), 11 | stiffness: Parameter(float) = 0.0, 12 | ) -> Return(Tensor): 13 | maximums = functools.reduce(torch.maximum, bounds) 14 | minimums = functools.reduce(torch.minimum, bounds) 15 | 16 | if stiffness: 17 | bounds = torch.stack(bounds) 18 | average = bounds.mean(dim=0) 19 | 20 | smallest_positive = maximums 21 | largest_negative = minimums 22 | 23 | for i, bound in enumerate(bounds): 24 | smallest_positive = torch.where((smallest_positive >= bound) & (bound >= average), bound, smallest_positive) 25 | largest_negative = torch.where((largest_negative <= bound) & (bound <= average), bound, largest_negative) 26 | 27 | maximums = (1-stiffness)*maximums + stiffness*smallest_positive 28 | minimums = (1-stiffness)*minimums + stiffness*largest_negative 29 | 30 | return torch.minimum(torch.maximum(a, minimums), maximums) 31 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/cosine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from sd_mecha.extensions.merge_methods import merge_method, Parameter, Return 4 | 5 | 6 | @merge_method 7 | def add_cosine_a( 8 | a: Parameter(Tensor, "weight"), 9 | b: Parameter(Tensor, "weight"), 10 | alpha: Parameter(Tensor) = 1.0, 11 | ) -> Return(Tensor, "weight"): 12 | a_norm = torch.nn.functional.normalize(a, dim=0) 13 | b_norm = torch.nn.functional.normalize(b, dim=0) 14 | similarity = 
torch.nn.functional.cosine_similarity(a_norm, b_norm, dim=0) 15 | return add_cosine_generic(a, b, alpha, similarity) 16 | 17 | 18 | @merge_method 19 | def add_cosine_b( 20 | a: Parameter(Tensor, "weight"), 21 | b: Parameter(Tensor, "weight"), 22 | alpha: Parameter(Tensor) = 1.0, 23 | ) -> Return(Tensor, "weight"): 24 | similarity = torch.nn.functional.cosine_similarity(a, b, dim=0) 25 | dot_product = torch.sum(a * b) 26 | magnitude_similarity = dot_product / (torch.norm(a) * torch.norm(b)) 27 | combined_similarity = (similarity + magnitude_similarity) / 2.0 28 | return add_cosine_generic(a, b, alpha, combined_similarity) 29 | 30 | 31 | def add_cosine_generic(a: Tensor, b: Tensor, alpha: Tensor, similarity: Tensor) -> Tensor: 32 | k = 1 - torch.clamp(similarity - alpha, 0, 1) 33 | return torch.lerp(a, b, k) 34 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/crossover.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Tuple 4 | from torch import Tensor 5 | from sd_mecha import merge_method, Parameter, Return 6 | 7 | 8 | @merge_method 9 | def crossover( 10 | a: Parameter(Tensor), 11 | b: Parameter(Tensor), 12 | alpha: Parameter(float) = 0.5, 13 | tilt: Parameter(float) = 0.0, 14 | ) -> Return(Tensor): 15 | if alpha == 0: 16 | return a 17 | if alpha == 1: 18 | return b 19 | if tilt == 1: 20 | return torch.lerp(a, b, alpha) 21 | 22 | if len(a.shape) == 0 or torch.allclose(a.half(), b.half()): 23 | return torch.lerp(a, b, tilt) 24 | 25 | shape = a.shape 26 | 27 | a_dft = torch.fft.rfftn(a, s=shape) 28 | b_dft = torch.fft.rfftn(b, s=shape) 29 | 30 | dft_filter = create_filter(a_dft.shape, alpha, tilt, device=a.device) 31 | 32 | x_dft = (1 - dft_filter)*a_dft + dft_filter*b_dft 33 | return torch.fft.irfftn(x_dft, s=shape) 34 | 35 | 36 | def create_filter(shape: Tuple[int, ...] 
def create_filter(shape: Tuple[int, ...] | torch.Size, alpha: float, tilt: float, device=None):
    """
    Create a crossover filter. The cut is first tilted around (0, 0), then slid along its normal until it touches the point (alpha, 1 - alpha).
    :param shape: shape of the filter
    :param alpha: the ratio between the low frequencies and high frequencies. must be in [0, 1]
        0 = all 0s, 1 = all 1s, 0s correspond to low frequencies and 1s correspond to high frequencies
    :param tilt: tilt of the filter. 0 = vertical filter, 0.5 = 45 degrees, 1 = degenerates to a weighted sum with alpha=alpha
    :param device: device of the filter
    :return: a tensor of shape `shape` with values in [0, 1]
    """
    if not 0 <= alpha <= 1:
        raise ValueError("alpha must be between 0 and 1")

    # normalize tilt to the range [0, 4]
    tilt -= math.floor(tilt // 4 * 4)
    if tilt > 2:
        # tilt in (2, 4] mirrors the filter: work with the complement of alpha
        # and flip the result back at the end
        alpha = 1 - alpha
        alpha_inverted = True
    else:
        alpha_inverted = False

    # one normalized frequency ramp per dimension (in [0, 1])
    gradients = list(reversed([
        torch.linspace(0, 1, s, device=device)
        if i == 0 or s == 1 else
        # negative frequencies are in the second half of the dimension
        torch.cat([
            torch.linspace(0, (s - 1) // 2, s - s // 2, device=device),
            torch.linspace(s // 2, 1, s // 2, device=device)
        ]) / (s // 2)
        for i, s in enumerate(reversed(shape))
    ]))

    if len(shape) > 1:
        # radial frequency magnitude: RMS of the per-axis ramps
        grids = torch.meshgrid(*(g**2 for g in gradients), indexing='ij')
        mesh = (torch.stack(grids).sum(dim=0) / len(shape)).sqrt()
    else:
        mesh = gradients[0]

    if tilt < 1e-10 or abs(tilt - 4) < 1e-10:
        # vertical cut: hard threshold at frequency 1 - alpha
        dft_filter = (mesh > 1 - alpha).float()
    elif abs(tilt - 2) < 1e-10:
        # tilt of exactly 2 inverts the hard threshold
        dft_filter = (mesh < 1 - alpha).float()
    else:
        # oblique cut: a linear ramp with slope cot(pi * tilt / 2), clipped to [0, 1]
        tilt_cot = 1 / math.tan(math.pi * tilt / 2)
        if tilt <= 1 or 2 < tilt <= 3:
            dft_filter = mesh*tilt_cot + alpha*tilt_cot + alpha - tilt_cot
        else:  # 1 < tilt <= 2 or 3 < tilt
            dft_filter = mesh*tilt_cot - alpha*tilt_cot + alpha
        dft_filter = dft_filter.clip(0, 1)

    if alpha_inverted:
        dft_filter = 1 - dft_filter
    return dft_filter
key = kwargs["key"] 32 | return sum(model[key] for model in models) / len(models) 33 | 34 | 35 | @merge_method 36 | def slerp( 37 | a: Parameter(Tensor), 38 | b: Parameter(Tensor), 39 | *, 40 | alpha: Parameter(Tensor) = 0.5, 41 | ) -> Return(Tensor): 42 | a_normalized = a / a.norm() 43 | b_normalized = b / b.norm() 44 | 45 | ab_dot = (a_normalized * b_normalized).sum().clamp(-1, 1) 46 | 47 | omega = torch.arccos(ab_dot) 48 | a_contrib = a_normalized * torch.sin((1-alpha)*omega) 49 | b_contrib = b_normalized * torch.sin(alpha*omega) 50 | res = (a_contrib + b_contrib) / torch.sin(omega) 51 | res *= torch.lerp(a.norm(), b.norm(), alpha) 52 | if res.isnan().any(): 53 | return torch.lerp(a, b, alpha) 54 | return res 55 | 56 | 57 | @merge_method 58 | def add_difference( 59 | a: Parameter(StateDict[Tensor], "weight"), 60 | b: Parameter(StateDict[Tensor], "delta"), 61 | alpha: Parameter(Tensor) = 1.0, 62 | **kwargs, 63 | ) -> Return(Tensor, "weight"): 64 | key = kwargs["key"] 65 | if alpha.numel() == 1 and math.isclose(alpha.item(), 0.0): 66 | return a[key] 67 | 68 | b_val = b[key] # try to load b from memory first in case it fails to merge before a 69 | return a[key].addcmul(b_val, alpha) 70 | 71 | 72 | @merge_method 73 | def subtract( 74 | a: Parameter(Tensor, "weight"), 75 | b: Parameter(Tensor, "weight"), 76 | ) -> Return(Tensor, "delta"): 77 | return a - b 78 | 79 | 80 | @merge_method 81 | def perpendicular_component( 82 | a: Parameter(Tensor), 83 | b: Parameter(Tensor), 84 | ) -> Return(Tensor): 85 | norm_a = torch.linalg.norm(a) 86 | res = b - a * (a / norm_a * (b / norm_a)).sum() 87 | if res.isnan().any(): 88 | return torch.zeros_like(a) 89 | return res 90 | 91 | 92 | @merge_method 93 | def train_difference_mask( 94 | a: Parameter(Tensor), 95 | b: Parameter(Tensor), 96 | c: Parameter(Tensor), 97 | alpha: Parameter(Tensor) = 1.0, 98 | ) -> Return(Tensor, "param"): 99 | return alpha * 1.8 * torch.nan_to_num((b - a).abs() / ((b - a).abs() + (b - c).abs()), nan=0) 100 
| 101 | 102 | @merge_method 103 | def add_opposite_mask( 104 | a: Parameter(Tensor), 105 | b: Parameter(Tensor), 106 | c: Parameter(Tensor), 107 | alpha: Parameter(Tensor) = 1.0, 108 | ) -> Return(Tensor, "param"): 109 | return alpha * 2 * torch.nan_to_num((a - b).abs() / ((a - b).abs() + (a + b - 2*c).abs()), nan=0) 110 | 111 | 112 | @merge_method 113 | def add_strict_opposite_mask( 114 | a: Parameter(Tensor), 115 | b: Parameter(Tensor), 116 | c: Parameter(Tensor), 117 | alpha: Parameter(Tensor) = 1.0, 118 | ) -> Return(Tensor, "param"): 119 | threshold = torch.maximum(torch.abs(a - c), torch.abs(b - c)) 120 | return alpha * torch.clamp(torch.nan_to_num((c - a) * (b - c) / threshold**2, nan=0), 0) * 2 121 | 122 | 123 | @merge_method 124 | def geometric_sum( 125 | a: Parameter(Tensor), 126 | b: Parameter(Tensor), 127 | alpha: Parameter(Tensor) = 0.5, 128 | ) -> Return(Tensor): 129 | a = torch.complex(a, torch.zeros_like(a)) 130 | b = torch.complex(b, torch.zeros_like(b)) 131 | res = a ** (1 - alpha) * b ** alpha 132 | return res.real 133 | 134 | 135 | @merge_method 136 | def multiply_quotient( 137 | a: Parameter(Tensor), 138 | b: Parameter(Tensor), 139 | c: Parameter(Tensor), 140 | alpha: Parameter(Tensor) = 1.0, 141 | ) -> Return(Tensor): 142 | ac_log = torch.log(a.abs()) - torch.log(c.abs()) 143 | bc_log = torch.log(b.abs()) - torch.log(c.abs()) 144 | 145 | b = torch.complex(b, torch.zeros_like(b)) 146 | c = torch.complex(c, torch.zeros_like(c)) 147 | 148 | threshold = torch.maximum(torch.abs(ac_log), torch.abs(bc_log)) 149 | alpha = alpha * torch.clamp(-torch.nan_to_num(ac_log * bc_log / threshold**2, nan=0), 0) 150 | 151 | res = a * (b / c)**alpha 152 | res = torch.where(torch.isnan(res), a, res) 153 | return res.real 154 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/logistics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
from torch import Tensor 3 | from typing import TypeVar 4 | from sd_mecha.extensions.merge_methods import merge_method, StateDict, Parameter, Return 5 | from sd_mecha.streaming import StateDictKeyError 6 | 7 | 8 | T = TypeVar("T") 9 | 10 | 11 | @merge_method 12 | def fallback( 13 | a: Parameter(StateDict[T]), 14 | default: Parameter(StateDict[T]), 15 | **kwargs, 16 | ) -> Return(T): 17 | key = kwargs["key"] 18 | try: 19 | return a[key] 20 | except StateDictKeyError: 21 | return default[key] 22 | 23 | 24 | @merge_method 25 | def cast( 26 | a: Parameter(Tensor), 27 | device: Parameter(str) = None, 28 | dtype: Parameter(str) = None, 29 | ) -> Return(Tensor): 30 | to_kwargs = {} 31 | if device is not None: 32 | to_kwargs["device"] = device 33 | 34 | if dtype is not None: 35 | if dtype not in cast_dtype_map: 36 | raise ValueError(f"Unknown dtype {dtype}. Possible values are None, {', '.join(cast_dtype_map.keys())}") 37 | to_kwargs["dtype"] = cast_dtype_map[dtype] 38 | 39 | return a.to(**to_kwargs) 40 | 41 | 42 | cast_dtype_map = { 43 | "float64": torch.float64, 44 | "int64": torch.int64, 45 | "float32": torch.float32, 46 | "int32": torch.int32, 47 | "float16": torch.float16, 48 | "bfloat16": torch.bfloat16, 49 | "int16": torch.int16, 50 | "float8_e4m3fn": torch.float8_e4m3fn, 51 | "float8_e5m2": torch.float8_e5m2, 52 | "int8": torch.int8, 53 | "bool": torch.bool, 54 | } 55 | for dtype_str in ("uint8", "uint16", "uint32", "uint64"): 56 | if hasattr(torch, dtype_str): 57 | cast_dtype_map[dtype_str] = getattr(torch, dtype_str) 58 | cast_dtype_map_reversed = {v: k for k, v in cast_dtype_map.items()} 59 | 60 | 61 | @merge_method 62 | def get_dtype( 63 | a: Parameter(Tensor), 64 | ) -> Return(str, "param"): 65 | return cast_dtype_map_reversed[a.dtype] 66 | 67 | 68 | @merge_method 69 | def get_device( 70 | a: Parameter(Tensor), 71 | ) -> Return(str, "param"): 72 | return str(a.device) 73 | 74 | 75 | @merge_method 76 | def pick_component( 77 | a: Parameter(StateDict[T]), 78 | 
component: Parameter(str, "param"), 79 | **kwargs, 80 | ) -> Return(T): 81 | if component not in a.model_config.components(): 82 | raise ValueError( 83 | f'Component "{component}" does not exist in config "{a.model_config.identifier}". ' 84 | f"Valid components: {tuple(a.model_config.components())}" 85 | ) 86 | 87 | key = kwargs["key"] 88 | if key in a.model_config.components()[component].keys(): 89 | return a[key] 90 | else: 91 | raise StateDictKeyError(key) 92 | 93 | 94 | @merge_method 95 | def omit_component( 96 | a: Parameter(StateDict[T]), 97 | component: Parameter(str, "param"), 98 | **kwargs, 99 | ) -> Return(T): 100 | if component not in a.model_config.components(): 101 | raise ValueError( 102 | f'Component "{component}" does not exist in config "{a.model_config.identifier}". ' 103 | f"Valid components: {tuple(a.model_config.components())}" 104 | ) 105 | 106 | key = kwargs["key"] 107 | if key in a.model_config.components()[component].keys(): 108 | raise StateDictKeyError(key) 109 | else: 110 | return a[key] 111 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/slicing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Tuple 4 | from torch import Tensor 5 | from sd_mecha.extensions.merge_methods import merge_method, Parameter, Return 6 | 7 | 8 | @merge_method 9 | def tensor_sum( 10 | a: Parameter(Tensor), 11 | b: Parameter(Tensor), 12 | width: Parameter(float) = 0.5, 13 | offset: Parameter(float) = 0.0, 14 | ) -> Return(Tensor): 15 | if a.shape == (): 16 | if width > 0.5: 17 | return b 18 | return a 19 | 20 | start_i, end_i, region_is_inverted = ratio_to_region(width, offset, a.size(0)) 21 | if region_is_inverted: 22 | b[start_i:end_i] = a[start_i:end_i] 23 | return b 24 | else: 25 | a[start_i:end_i] = b[start_i:end_i] 26 | return a 27 | 28 | 29 | @merge_method 30 | def top_k_tensor_sum( 31 | a: 
Parameter(Tensor), 32 | b: Parameter(Tensor), 33 | width: Parameter(float) = 1.0, 34 | offset: Parameter(float) = 0.0, 35 | ) -> Return(Tensor): 36 | a_flat = torch.flatten(a) 37 | a_dist = torch.msort(a_flat) 38 | b_indices = torch.argsort(torch.flatten(b), stable=True) 39 | redist_indices = torch.argsort(b_indices) 40 | 41 | start_i, end_i, region_is_inverted = ratio_to_region(width, offset, torch.numel(a)) 42 | start_top_k = kth_abs_value(a_dist, start_i) 43 | end_top_k = kth_abs_value(a_dist, end_i) 44 | 45 | indices_mask = (start_top_k <= torch.abs(a_dist)) & (torch.abs(a_dist) <= end_top_k) 46 | if region_is_inverted: 47 | indices_mask = ~indices_mask 48 | indices_mask = torch.gather(indices_mask.float(), 0, redist_indices) 49 | 50 | a_redist = torch.gather(a_dist, 0, redist_indices) 51 | a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist 52 | return a_redist.reshape_as(a) 53 | 54 | 55 | def kth_abs_value(a: Tensor, k: int) -> Tensor: 56 | if k <= 0: 57 | return torch.tensor(-1, device=a.device) 58 | else: 59 | return torch.kthvalue(torch.abs(a.float()), k)[0] 60 | 61 | 62 | def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]: 63 | if width < 0: 64 | offset += width 65 | width = -width 66 | width = min(width, 1) 67 | 68 | if offset < 0: 69 | offset = 1 + offset - int(offset) 70 | offset = math.fmod(offset, 1.0) 71 | 72 | if width + offset <= 1: 73 | inverted = False 74 | start = offset * n 75 | end = (width + offset) * n 76 | else: 77 | inverted = True 78 | start = (width + offset - 1) * n 79 | end = offset * n 80 | 81 | return round(start), round(end), inverted 82 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/merge_methods/svd.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import torch 4 | from typing import Optional, Tuple 5 | from torch import Tensor 6 | from sd_mecha 
@merge_method
def rotate(
    a: Parameter(StateDict[Tensor]),
    b: Parameter(StateDict[Tensor]),
    alignment: Parameter(float) = 1.0,
    alpha: Parameter(Tensor) = 0.0,
    centralization: Parameter(float) = 1.0,
    stiefel_eps: Parameter(float) = 1e-8,
    stiefel_max_iters: Parameter(int) = 100,
    **kwargs,
) -> Return(Tensor):
    """
    Align model A with model B by an orthogonal transform.

    Useful properties: alignment=alpha=0 returns model A, whereas alignment=alpha=1 returns model B.

    :param a: model A
    :param b: model B
    :param alignment: decides how much to align a to b by an orthogonal matrix. fractional values between 0 and 1 allow to rotate model A towards model B partially. An alignment of 1 means to minimize the distance between the two by rotation alone (no scaling allowed)
    :param alpha: interpolates the scaling component of model A with model B's. This interpolates the part of model A that is not affected by the orthogonal matrix
    :param centralization: how much to center the rows of model A and model B before applying the alignment. centering the rows allows to align model A to model B a lot more closely
    :param stiefel_eps: acceptable error for wide matrices with fractional alignment
    :param stiefel_max_iters: maximum number of iterations for wide matrices with fractional alignment
    :return: model A rotated towards B by an orthogonal transform Q^alignment, after centralizing model A
    """
    key = kwargs["key"]

    # fast paths for the exact endpoints (only valid when alpha is a scalar)
    if alpha.numel() == 1:
        alpha_float = alpha.item()
        if math.isclose(alignment, 0.0) and math.isclose(alpha_float, 0.0):
            return a[key]
        if math.isclose(alignment, 1.0) and math.isclose(alpha_float, 1.0):
            return b[key]

    a = a[key]
    b = b[key]
    # 1-d keys cannot be rotated; identical tensors (at half precision) need no alignment
    if len(a.shape) <= 1 or torch.allclose(a.half(), b.half()):
        return torch.lerp(a, b, alpha)

    # fold conv kernels to 2-d: (out*in, kh*kw) for convs, (out, rest) otherwise
    is_conv = len(a.shape) == 4 and a.shape[-2:].numel() != 1
    if is_conv:
        shape_2d = a.shape[:2].numel(), a.shape[2:].numel()
    else:
        shape_2d = a.shape[:1].numel(), a.shape[1:].numel()

    cache = kwargs.get("cache")
    if cache is not None:
        key = kwargs["key"]  # redundant: key was already read above, kept for clarity
        if key not in cache:
            cache[key] = {}
        cache = cache[key]

    if cache is not None:
        # if centralization is different from the cached value, invalidate cache
        if not math.isclose(cache.get("centralization", centralization), centralization):
            cache.clear()
        else:
            cache["centralization"] = centralization

    a_neurons = a.reshape(*shape_2d)
    b_neurons = b.reshape(*shape_2d)
    a_centroid = a_neurons.mean(0) * centralization
    b_centroid = b_neurons.mean(0) * centralization
    a_neurons -= a_centroid
    b_neurons -= b_centroid

    alignment_is_float = not math.isclose(alignment, round(alignment))

    if cache is not None and "transform" in cache:
        transform = cache["transform"].to(device=a.device, dtype=a.dtype)
    else:
        transform = orthogonal_procrustes(a_neurons, b_neurons, cancel_reflection=alignment_is_float)
        if cache is not None:
            cache["transform"] = transform.to(device="cpu", dtype=torch.bfloat16)

    # interpolate the non-rotational (scaling) component of A towards B's
    if alpha.numel() > 1 or not math.isclose(alpha.item(), 0):
        a_neurons = torch.lerp(a_neurons, transform(b_neurons, -1, cache, key), alpha)

    a_neurons = transform(a_neurons, alignment, cache, key, stiefel_eps=stiefel_eps, stiefel_max_iters=stiefel_max_iters)
    a_neurons += torch.lerp(a_centroid, b_centroid, alignment)
    return a_neurons.reshape_as(a)


@merge_method
def truncate_rank(
    a: Parameter(Tensor, merge_space="delta"),
    rank_ratio: Parameter(float) = 0.5,
    use_approximate_basis: Parameter(bool) = True,
    approximate_basis_iters: Parameter(int) = 2,
    approximate_basis_seed: Parameter(int) = None,
    **kwargs,
) -> Return(Tensor, merge_space="delta"):
    """Truncate the delta `a` to a fraction of its full rank via (approximate) SVD.

    :param a: delta tensor; flattened to 2-d beyond the first dim before decomposition
    :param rank_ratio: fraction of the full rank to keep, in [0, 1]
    :param use_approximate_basis: use randomized low-rank SVD instead of a full SVD
    :param approximate_basis_iters: power iterations for the randomized basis
    :param approximate_basis_seed: seed for the randomized basis (None = nondeterministic)
    :return: the rank-truncated delta, reshaped to the original shape
    """
    if a.dim() < 2:
        return a

    cache = kwargs.get("cache")
    if cache is not None:
        key = kwargs["key"]
        if key not in cache:
            cache[key] = {}
        cache = cache[key]

    a_2d = a.flatten(start_dim=1)
    max_rank = min(a_2d.shape)
    target_rank = min(max(round(max_rank * rank_ratio), 0), max_rank)
    if target_rank == max_rank:
        return a
    if target_rank == 0:
        return torch.zeros_like(a)

    original_shape = a.shape
    if (
        # fix: guard against cache=None — the original evaluated `"s" in cache`
        # unconditionally, raising TypeError whenever no cache was provided
        cache is not None and
        "s" in cache and cache["s"].numel() >= target_rank and
        cache.get("iters", approximate_basis_iters) == approximate_basis_iters and
        cache.get("seed", approximate_basis_seed) == approximate_basis_seed
    ):
        u = cache["u"][..., :target_rank].to(a)
        s = cache["s"][..., :target_rank].to(a)
        vh = cache["vh"][..., :target_rank, :].to(a)
    else:
        svd_driver = "gesvda" if a.is_cuda else None
        if use_approximate_basis:
            u, s, vh = svd_lowrank(a_2d, rank=target_rank, iters=approximate_basis_iters, seed=approximate_basis_seed, driver=svd_driver)
        else:
            u, s, vh = torch.linalg.svd(a_2d, full_matrices=False, driver=svd_driver)
        if cache is not None:
            # store factors in bfloat16 on cpu to bound cache memory
            cache["u"] = u.to(device="cpu", dtype=torch.bfloat16)
            cache["s"] = s.to(device="cpu", dtype=torch.bfloat16)
            cache["vh"] = vh.to(device="cpu", dtype=torch.bfloat16)
            if use_approximate_basis:
                cache["iters"] = approximate_basis_iters
                cache["seed"] = approximate_basis_seed
            else:
                cache.pop("iters", None)
                cache.pop("seed", None)

    return (u[..., :target_rank] * s[..., :target_rank].unsqueeze(-2) @ vh[..., :target_rank, :]).reshape(original_shape)


def orthogonal_procrustes(a, b, cancel_reflection: bool = False):
    """Return a callable orthogonal transform Q minimizing ||aQ - b||_F.

    For wide matrices (n < p) a low-rank factored transform is returned;
    otherwise a full orthogonal matrix. With cancel_reflection, the last
    singular direction is flipped as needed so det(Q) = +1.
    """
    n, p = a.shape[-2:]
    if n < p:
        svd_driver = "gesvd" if a.is_cuda else None
        u, _, vh = svd_lowrank(a.mH @ b, rank=a.shape[0], driver=svd_driver)
        return LowRankOrthogonalMatmul(u, vh)
    else:
        svd_driver = "gesvd" if a.is_cuda else None
        u, _, vh = torch.linalg.svd(a.mH @ b, driver=svd_driver)
        if cancel_reflection:
            u[..., -1] /= torch.slogdet(u @ vh)[0]

        return FullRankOrthogonalMatmul(u @ vh)


class LowRankOrthogonalMatmul:
    """Applies a low-rank factored orthogonal transform x @ (u @ vh)^t.

    The component of x orthogonal to the rank subspace (x_proj) passes
    through unchanged; only the in-subspace component is rotated.
    """

    def __init__(self, u, vh):
        self.u = u
        self.vh = vh

    def __call__(self, x: Tensor, t: float | int = 1.0, cache: Optional[dict] = None, key: Optional[str] = None, stiefel_eps=1e-8, stiefel_max_iters=100, **_kwargs):
        def x_proj(): return x - x @ self.vh.mH @ self.vh

        if math.isclose(t, 0.0):
            return x
        elif math.isclose(t, 1.0):
            return x_proj() + x @ self.u @ self.vh
        elif math.isclose(t, -1.0):
            return x_proj() + x @ self.vh.mH @ self.u.mH
        elif math.isclose(t, round(t)):
            # integer powers reduce to a matrix power of the small core (vh @ u)
            if t > 0:
                return x_proj() + x @ self.u @ torch.linalg.matrix_power(self.vh @ self.u, round(t) - 1) @ self.vh
            else:
                return x_proj() + x @ self.vh.mH @ torch.linalg.matrix_power(self.u.mH @ self.vh.mH, abs(round(t)) - 1) @ self.u.mH
        else:
            # fractional power: geodesic interpolation on the Stiefel manifold
            u = stiefel_interpolate(self.vh.mH, self.u, t, stiefel_eps, stiefel_max_iters, cache, key)
            return x_proj() + x @ u @ self.vh

    def to(self, *args, **kwargs):
        return LowRankOrthogonalMatmul(self.u.to(*args, **kwargs), self.vh.to(*args, **kwargs))
def fractional_orthogonal_matrix_power(q, t, cache=None, key=None):
    """Compute Q^t for an orthogonal matrix Q, dispatching cheap special cases.

    t in {0, 1, -1, other integers} is handled exactly; fractional t falls
    back to the eigendecomposition-based power.
    """
    if math.isclose(t, 0.0):
        return torch.eye(q.shape[-1], device=q.device, dtype=q.dtype)
    elif math.isclose(t, 1.0):
        return q
    elif math.isclose(t, -1.0):
        return q.mH
    elif math.isclose(t, round(t)):
        return torch.linalg.matrix_power(q, round(t))
    else:
        return orthogonal_matrix_power(q, t, cache, key)


def orthogonal_matrix_power(q, power, cache=None, key=None):
    """Fractional power of an orthogonal matrix via complex eigendecomposition.

    The eigendecomposition is optionally cached (as real-viewed bfloat16 on cpu).
    Warns when the imaginary residual of the reconstruction is non-negligible.
    """
    if cache is not None and "eig_v" in cache:
        eig_v = torch.view_as_complex(cache["eig_v"].to(device=q.device, dtype=q.dtype))
        eig_vs = torch.view_as_complex(cache["eig_vs"].to(device=q.device, dtype=q.dtype))
    else:
        eig_v, eig_vs = torch.linalg.eig(q)
        if cache is not None:
            cache["eig_v"] = torch.view_as_real(eig_v).to(device="cpu", dtype=torch.bfloat16)
            cache["eig_vs"] = torch.view_as_real(eig_vs).to(device="cpu", dtype=torch.bfloat16)

    eig_v_pow = eig_v**power
    result = eig_vs * eig_v_pow.unsqueeze(-2) @ eig_vs.mH
    if result.imag.abs().max() > 1e-6:
        logging.warning(f"imaginary residual in fractional matrix power: max|Im Q^p| = {result.imag.abs().max().item()}, key: {key}")
    return result.to(dtype=q.dtype)


# src: https://github.com/pytorch/pytorch/blob/f714599c57b3854460002335df7d67af98f12176/torch/_lowrank.py#L150
# license applies, see /LICENSE-pytorch.txt
def svd_lowrank(a: Tensor, rank: int, iters: int = 0, seed: Optional[int] = None, driver: Optional[str] = None) -> Tuple[Tensor, Tensor, Tensor]:
    """Randomized low-rank SVD: a ≈ u @ diag(s) @ vh with `rank` components.

    :param a: input matrix (last two dims decomposed)
    :param rank: number of singular components to keep
    :param iters: power iterations to refine the random range basis
    :param seed: seed for the random test matrix (fix: was annotated `int`
        despite defaulting to None — now an explicit Optional[int])
    :param driver: cuSOLVER driver name forwarded to torch.linalg.svd
    """
    m, n = a.shape[-2:]

    # work on the tall orientation; un-transpose the factors at the end
    if m < n:
        a = a.mH

    q = get_approximate_basis(a, rank, iters=iters, seed=seed)
    b = q.mH @ a
    u, s, vh = torch.linalg.svd(b, full_matrices=False, driver=driver)
    u = q @ u

    if m < n:
        u, vh = vh.mH, u.mH

    return u, s, vh


# src: https://github.com/pytorch/pytorch/blob/f714599c57b3854460002335df7d67af98f12176/torch/_lowrank.py#L12
# license applies, see /LICENSE-pytorch.txt
def get_approximate_basis(a: Tensor, rank: int, iters: int = 0, seed: Optional[int] = None) -> Tensor:
    """Orthonormal basis approximately spanning the range of `a` (rank columns)."""
    generator = None
    if seed is not None:
        generator = torch.Generator(a.device)
        generator.manual_seed(seed)

    r = torch.randn(a.shape[-1], rank, dtype=a.dtype, device=a.device, generator=generator)
    q = torch.linalg.householder_product(*torch.geqrf(a @ r))
    for _ in range(iters):  # fix: loop index was unused
        # one power iteration: project back and forth through `a` and re-orthonormalize
        q = torch.linalg.householder_product(*torch.geqrf(a.mH @ q))
        q = torch.linalg.householder_product(*torch.geqrf(a @ q))
    return q


def orthogonal_complete(a: Tensor) -> Tensor:
    """Extend the orthonormal columns of a tall matrix `a` to a full orthogonal basis."""
    m, n = a.shape[-2:]
    if m <= n:
        return a

    # orthonormal basis of the complement of col(a), built from the projector residual
    proj = torch.eye(m, device=a.device, dtype=a.dtype)[:, n:] - a @ a.mH[..., n:]
    a_extension = torch.linalg.householder_product(*torch.geqrf(proj))
    return torch.cat((a, a_extension), dim=-1)


def stiefel_interpolate(a, b, t, eps=1e-8, max_iters=100, cache=None, key=None):
    """Geodesic interpolation from `a` towards `b` on the Stiefel manifold, at time t."""
    delta = log_stiefel(a, b, eps, max_iters, cache, key)
    res = exp_stiefel(a, t * delta)
    return res
def log_stiefel(a, b, eps=1e-8, max_iters=100, cache=None, key=None):
    """Inverse of exp_stiefel: find the tangent `delta` at `a` that reaches `b`.

    Iterative fixed-point algorithm on the orthogonal completion of (a, b);
    converges when the norm of the lower-right block of the matrix logarithm
    drops below `eps`. Results are cached (bfloat16, cpu) keyed by the state
    dict key; the cached value is reused only when `eps` matches and the
    previous run either converged within the current iteration budget or
    exhausted exactly the same budget without converging.

    :param a: tall orthonormal matrix (n > p), possibly batched
    :param b: tall orthonormal matrix of the same shape as `a`
    :param eps: convergence tolerance on the residual block norm
    :param max_iters: maximum number of fixed-point iterations (>= 1)
    :param cache: optional per-key dict persisted across calls
    :param key: state dict key, used for logging and cache bookkeeping
    """
    if (
        cache is not None and "log_stiefel" in cache and
        # compare tolerances on a log scale so tiny absolute differences don't invalidate
        math.isclose(math.log10(cache.get("log_stiefel_eps", eps)), math.log10(eps)) and (
            cache["log_stiefel_converged"] and cache["log_stiefel_iters"] < max_iters or

            # possible optimization: start from current cache["log_stiefel"] when max_iters > cache["log_stiefel_iters"]
            not cache["log_stiefel_converged"] and cache["log_stiefel_iters"] == max_iters
        )
    ):
        return cache["log_stiefel"].to(device=a.device, dtype=a.dtype)

    # flatten any leading batch dims; restore at the end
    original_shape = a.shape
    a = a.view(-1, original_shape[-2], original_shape[-1])
    b = b.view(-1, original_shape[-2], original_shape[-1])
    batch_size, n, p = b.shape
    assert n > p, "a and b should be tall, not square nor wide"
    k = min(n-p, p)
    assert max_iters >= 1, "max_iters should be at least 1"

    m = a.mH @ b

    # orthonormal residual of b against col(a), reduced to at most k directions
    q, n_mat = qr_pos(b - a @ m)
    q = q[..., :k]
    n_mat = n_mat[..., :k, :]
    v = orthogonal_complete(torch.cat((m, n_mat), dim=-2))

    # rotate the completion block to a canonical (diagonal) form via its SVD
    r, sigma, r_hat_t = torch.linalg.svd(v[..., p:, p:], driver="gesvd" if v.is_cuda else None)
    q @= r
    v[..., p:, :p] = r.mH @ n_mat
    v[..., :p, p:] @= r_hat_t.mH
    p_arange = torch.arange(p, p+k, device=v.device)
    v[..., p:, p:].zero_()
    v[..., p_arange, p_arange] = sigma
    del r, sigma, r_hat_t, p_arange

    # flip one column where needed so every batch element has det(v) = +1
    v[v.slogdet()[0] < 0, ..., -1] *= -1

    k_arange = torch.arange(k, device=v.device, dtype=torch.long)
    printed_error = False
    l = None
    converged = False
    for i in range(max_iters):
        l = logm(v, key)
        c = l[..., p:, p:]
        # track the worst-converging batch element only
        c_norm_idx = torch.linalg.matrix_norm(c).argmax()
        c_norm = torch.linalg.matrix_norm(c[c_norm_idx])
        if c_norm > 10:
            logging.warning(f"log_stiefel: very high c_norm={c_norm.item():0.3f} at iteration {i}, batch {c_norm_idx}, key: {key}")
            printed_error = True
        elif printed_error:
            logging.warning(f"log_stiefel: started converging c_norm={c_norm.item():0.3f} at iteration {i}, batch {c_norm_idx}, key: {key}")
            printed_error = False
        if c_norm <= eps:
            converged = True
            break
        elif i == max_iters - 1:
            logging.warning(f"log_stiefel: {c_norm.item()}, batch {c_norm_idx}, key: {key}")

        # update step: shrink the residual block through a Sylvester solve
        s = l[..., p:, :p] @ l[..., p:, :p].mH / 12
        s[..., k_arange, k_arange] -= 0.5
        g = solve_symmetric_sylvester(s, c)
        v[..., p:] @= torch.linalg.matrix_exp(g)

    delta = a @ l[..., :p, :p] + q @ l[..., p:, :p]
    res = delta.reshape(original_shape)

    if cache is not None:
        cache["log_stiefel"] = res.to(device="cpu", dtype=torch.bfloat16)
        cache["log_stiefel_eps"] = eps
        cache["log_stiefel_iters"] = i + 1
        cache["log_stiefel_converged"] = converged

    return res


def solve_symmetric_sylvester(s, c):
    """Solve S G + G S = C for G, with S symmetric, via the eigenbasis of S."""
    v, vs = torch.linalg.eigh(s)
    c_t = vs.mH @ c @ vs
    # pairwise eigenvalue sums λ_i + λ_j form the elementwise divisor
    d = v.unsqueeze(-2) + v.unsqueeze(-1)
    if torch.any(torch.abs(d) < 1e-12):
        logging.warning("Singular Sylvester operator: some λ_i+λ_j ≈ 0")

    g_t = c_t / d
    g = vs @ g_t @ vs.mH
    return g


def logm(m, key):
    """Batched matrix logarithm via complex eigendecomposition.

    Warns when the imaginary residual is large (the input should be close
    to a rotation, whose log is real); the result is cast back to m's dtype.
    """
    v, vs = torch.linalg.eig(m)
    v_log = v.unsqueeze(-2).log()
    # vs @ diag(log v) @ vs^-1, computed as a right-solve to avoid an explicit inverse
    res = torch.linalg.solve(vs, vs*v_log, left=False)

    max_v, _ = res.imag.abs().flatten(start_dim=-2).max(dim=-1)
    if max_v[max_v.argmax()] > 1e-4:
        logging.warning(f"imaginary residual at batch index {max_v.argmax()}: {max_v[max_v.argmax()].item()}, key: {key}")
    return res.to(m.dtype)
@merge_method
def ties_sum_with_dropout(
    *deltas: Parameter(Tensor, "delta"),
    probability: Parameter(Tensor) = 0.9,
    della_eps: Parameter(float) = 0.0,
    rescale: Parameter(bool) = True,
    k: Parameter(float) = 0.2,
    vote_sgn: Parameter(bool) = False,
    apply_stock: Parameter(bool) = False,
    cos_eps: Parameter(float) = 1e-6,
    apply_median: Parameter(bool) = False,
    eps: Parameter(float) = 1e-6,
    maxiter: Parameter(int) = 100,
    ftol: Parameter(float) = 1e-20,
    seed: Parameter(int) = None,
) -> Return(Tensor, "delta"):
    """DARE/DELLA-style merge: randomly drop delta entries, TIES-merge the rest, then rescale.

    :param deltas: task-vector deltas to merge
    :param probability: per-element dropout probability (math.isclose on it
        assumes a 1-element tensor — TODO confirm the wrapper guarantees this)
    :param della_eps: when non-zero, skews the keep-probability of each element
        by its magnitude rank (DELLA)
    :param rescale: divide the merged delta by the expected keep rate (1 - probability)
    :param seed: RNG seed for the dropout mask; negative or None means unseeded
    (remaining parameters are forwarded to ties_sum_extended)
    """
    # dropping everything (probability ~ 1) leaves nothing to merge; 0 is the neutral delta
    if not deltas or math.isclose(probability, 1.0):
        return 0

    generator = torch.Generator(deltas[0].device)
    if seed is not None and seed >= 0:
        generator.manual_seed(seed)

    deltas = [delta * find_della_dropout(delta, probability, della_eps, generator) for delta in deltas]
    deltas = ties_sum_extended.__wrapped__(
        *deltas,
        k=k,
        vote_sgn=vote_sgn,
        apply_stock=apply_stock,
        cos_eps=cos_eps,
        apply_median=apply_median,
        eps=eps,
        maxiter=maxiter,
        ftol=ftol,
    )

    # NOTE(review): the isclose(probability, 1.0) arm is dead code — that case
    # already returned 0 above; only `not rescale` can select rescalar = 1.0 here
    if math.isclose(probability, 1.0) or not rescale:
        rescalar = 1.0
    else:
        rescalar = 1.0 - probability
    return deltas / rescalar
@merge_method
def ties_sum_extended(
    *models: Parameter(Tensor, "delta"),
    k: Parameter(float) = 0.2,
    vote_sgn: Parameter(bool) = False,
    apply_stock: Parameter(bool) = False,
    apply_median: Parameter(bool) = False,
    cos_eps: Parameter(float) = 1e-6,
    eps: Parameter(float) = 1e-6,
    maxiter: Parameter(int) = 100,
    ftol: Parameter(float) = 1e-20,
) -> Return(Tensor, "delta"):
    """TIES merge with optional Model Stock scaling or geometric-median aggregation.

    :param models: task-vector deltas to merge
    :param k: fraction of largest-magnitude entries each delta keeps (see filter_top_k)
    :param vote_sgn: vote on the elect sign with summed values instead of sign counts
    :param apply_stock: scale the merged delta by the Model Stock factor t
    :param apply_median: aggregate with the geometric median instead of the mean
    :param cos_eps: epsilon for the cosine similarity used by Model Stock
    :param eps, maxiter, ftol: forwarded to geometric_median when apply_median is set
    """
    filtered_delta, param_counts = ties_sum_deltas(*models, k=k, vote_sgn=vote_sgn)

    if apply_median:
        # aggregate the sign-filtered deltas with the geometric median instead of the mean
        filtered_delta = geometric_median.__wrapped__(*filtered_delta, eps=eps, maxiter=maxiter, ftol=ftol)
    else:
        # fix: the condition was inverted — `t = 1.0 if apply_stock else get_model_stock_t(...)`
        # applied the Model Stock factor exactly when apply_stock was False
        t = 1.0 if not apply_stock else get_model_stock_t(filtered_delta, cos_eps=cos_eps)
        filtered_delta = filtered_delta.sum(dim=0)
        filtered_delta = filtered_delta * t / param_counts

    # param_counts can be 0 where no sign agreed: map the resulting nan to 0
    return torch.nan_to_num(filtered_delta)


# src: https://arxiv.org/abs/2306.01708
@merge_method
def ties_sum(
    *models: Parameter(Tensor, "delta"),
    k: Parameter(float) = 1.0,
    vote_sgn: Parameter(bool) = False,
) -> Return(Tensor, "delta"):
    """Plain TIES merge: trim, elect sign, and average the agreeing deltas.

    :param models: task-vector deltas to merge
    :param k: fraction of largest-magnitude entries each delta keeps
    :param vote_sgn: vote on the elect sign with summed values instead of sign counts
    """
    filtered_delta, param_counts = ties_sum_deltas(*models, k=k, vote_sgn=vote_sgn)
    return torch.nan_to_num(filtered_delta.sum(dim=0) / param_counts)
def filter_top_k(a: Tensor, k: float):
    """Zero out all but the top-k fraction of entries of `a` by magnitude.

    :param a: input tensor
    :param k: fraction of entries to keep, in (0, 1]; k=1 keeps everything
    :return: `a` with the smallest (1 - k) fraction of |entries| set to 0
    """
    # index of the threshold entry among sorted magnitudes (at least 1)
    threshold_index = max(int((1 - k) * a.numel()), 1)
    threshold, _ = a.flatten().abs().float().kthvalue(threshold_index)
    keep_mask = (a.abs() >= threshold).float()
    return a * keep_mask
# src: https://github.com/krishnap25/geom_median/blob/main/src/geom_median/torch/weiszfeld_list_of_array.py
@merge_method
def geometric_median(
    *models: Parameter(Tensor),
    eps: Parameter(float) = 1e-6,
    maxiter: Parameter(int) = 100,
    ftol: Parameter(float) = 1e-20,
) -> Return(Tensor):
    """Geometric median of the input tensors via Weiszfeld iterations.

    :param models: tensors of identical shape
    :param eps: lower clamp on point-to-median distances (avoids division by zero)
    :param maxiter: maximum number of Weiszfeld iterations
    :param ftol: relative objective-decrease tolerance for early stopping
    """
    median = weighted_average(models)
    weights = new_weights = torch.ones(len(models), device=models[0].device, dtype=models[0].dtype)
    objective_value = geometric_median_objective(median, models, weights)

    # Weiszfeld iterations: reweight each point by the inverse of its distance to the current median
    for _ in range(max(0, maxiter)):
        prev_obj_value = objective_value
        denom = torch.stack([torch.dist(p, median) for p in models])
        new_weights = weights / torch.clamp(denom, min=eps)
        median = weighted_average(models, new_weights)

        # the objective is tracked with the original (uniform) weights, as in the reference implementation
        objective_value = geometric_median_objective(median, models, weights)
        if abs(prev_obj_value - objective_value) <= ftol * objective_value:
            break

    return weighted_average(models, new_weights)


def weighted_average(
    points: Sequence[float | Tensor] | Tensor,
    weights: Optional[Sequence[float | Tensor] | Tensor] = None
) -> float | Tensor:
    """Arithmetic mean of `points`, optionally weighted by `weights`."""
    if weights is not None:
        return sum(p * weights[i] for i, p in enumerate(points)) / sum(weights)
    else:
        return sum(points) / len(points)


def geometric_median_objective(median, points: Tuple, weights):
    """Weighted mean distance from `median` to each point."""
    return torch.mean(torch.stack([torch.dist(point, median) for point in points]) * weights)
def overlapping_sets_pmf(n, p, overlap: float, overlap_emphasis):
    """Probability mass function over which subset of the n deltas claims each element.

    Index 0 is the probability `p` that no delta claims the element; index i
    (1 <= i < 2**n) corresponds to the subset whose members are the set bits
    of i, sharing the remaining mass (1 - p).

    :param n: number of deltas
    :param p: dropout probability (mass of the empty subset)
    :param overlap: controls how much the masks overlap; integer values snap to
        the pure cases (even: disjoint singletons, odd: full overlap)
    :param overlap_emphasis: blends the overlap pmf towards the binomial pmf
    :return: numpy array of length 2**n summing to 1
    """
    if np.isclose(overlap, round(overlap)):
        if round(overlap) % 2 == 0:
            # even integer overlap: uniform over the n singleton subsets
            pmf = np.array([1/n*float(bin(i).count("1") == 1) for i in range(1, 2**n)])
        else:
            # odd integer overlap: all mass on the full subset
            pmf = np.array([0 for _ in range(1, 2**n - 1)] + [1])
    else:
        if math.floor(overlap) % 2 == 1:
            overlap = -overlap

        # softmax over subset sizes, sharpness driven by tan(pi*(overlap - 0.5))
        tan_overlap = np.tan(np.pi * (overlap - 0.5))
        pmf = np.zeros(2 ** n - 1)
        for i in range(1, 2 ** n):
            num_sets = bin(i).count("1")
            pmf[i-1] = tan_overlap*(num_sets - n/2)
        pmf = np.exp(pmf) / np.sum(np.exp(pmf))

    # binomial pmf over subset sizes, spread uniformly over the subsets of each size
    binomial_pmf = binom.pmf(np.arange(1, n + 1), n, p)
    expanded_binomial_pmf = np.zeros(2 ** n - 1)
    for i in range(1, 2 ** n):
        num_sets = bin(i).count("1")
        expanded_binomial_pmf[i-1] = binomial_pmf[num_sets-1] / binomial_coefficient_np(n, num_sets)
    expanded_binomial_pmf /= expanded_binomial_pmf.sum()

    # fix: the original called torch.lerp on numpy arrays, which raises
    # TypeError (torch.lerp only accepts Tensors). Same interpolation in
    # numpy: lerp(a, b, w) == a + w * (b - a)
    blend_weight = 1 - abs(2 * overlap - 1)
    blended = pmf + blend_weight * (expanded_binomial_pmf - pmf)
    pmf = pmf + overlap_emphasis * (blended - pmf)

    return np.concatenate([[p], pmf * (1 - p)])


def binomial_coefficient_np(n, k):
    """Exact integer binomial coefficient C(n, k) using numpy int64 arithmetic."""
    # use the symmetry C(n, k) == C(n, n-k) to shorten the product
    if k > n - k:
        k = n - k
    result = np.int64(1)
    for i in range(1, k+1):
        result = result * (n - i + 1) // i
    return result
import convert_sd1_blocks 20 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/model_configs/convert_flux.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from sd_mecha import merge_method, Parameter, Return, StateDict, StateDictKeyError 3 | 4 | 5 | @merge_method(is_conversion=True) 6 | def convert_flux_to_backbone_only( 7 | flux: Parameter(StateDict[Tensor], model_config="flux-flux"), 8 | **kwargs, 9 | ) -> Return(Tensor, model_config="flux-flux_diffuser_only"): 10 | diffuser_only_key = kwargs["key"] 11 | full_key = f"model.diffusion_model.{diffuser_only_key}" 12 | return flux[full_key] 13 | 14 | 15 | @merge_method(is_conversion=True) 16 | def convert_flux_backbone_to_full( 17 | flux_diffuser: Parameter(StateDict[Tensor], model_config="flux-flux_diffuser_only"), 18 | **kwargs, 19 | ) -> Return(Tensor, model_config="flux-flux"): 20 | full_key = kwargs["key"] 21 | if not full_key.startswith("model.diffusion_model."): 22 | raise StateDictKeyError(full_key) 23 | 24 | diffuser_only_key = full_key[len("model.diffusion_model."):] 25 | return flux_diffuser[diffuser_only_key] 26 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/model_configs/convert_huggingface_sd_vae_to_original.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sd_mecha.extensions.merge_methods import StateDict 3 | 4 | 5 | def convert_vae(huggingface_sd: StateDict[torch.Tensor], ldm_key: str) -> torch.Tensor: 6 | huggingface_key = ".".join(ldm_key.split(".")[1:]) 7 | 8 | needs_reshape = False 9 | for sd_weight_name, fake_weight_name in vae_extra_conversion_map.items(): 10 | if f"mid.attn_1.{sd_weight_name}.weight" in huggingface_key or f"mid.attn_1.{sd_weight_name}.bias" in huggingface_key: 11 | needs_reshape = True 12 | huggingface_key = 
def reshape_weight_for_sd(w):
    """Give HF linear weights the trailing 1x1 spatial dims SD conv2d expects."""
    if w.ndim == 1:
        return w
    return w.reshape(*w.shape, 1, 1)


vae_conversion_map = {
    # (original, huggingface)
    "nin_shortcut": "conv_shortcut",
    "norm_out": "conv_norm_out",
    "mid.attn_1.": "mid_block.attentions.0.",
}
# NOTE: convert_vae applies these replacements in insertion order, so the
# order below must stay identical to the original construction order.
for _level in range(4):
    # encoder down path: two resnets per level
    for _block in range(2):
        vae_conversion_map[f"encoder.down.{_level}.block.{_block}."] = f"encoder.down_blocks.{_level}.resnets.{_block}."

    if _level < 3:
        vae_conversion_map[f"down.{_level}.downsample."] = f"down_blocks.{_level}.downsamplers.0."

    vae_conversion_map[f"up.{3 - _level}.upsample."] = f"up_blocks.{_level}.upsamplers.0."

    # decoder up path: three resnets per level; hf numbers up blocks in
    # reverse relative to sd
    for _block in range(3):
        vae_conversion_map[f"decoder.up.{3 - _level}.block.{_block}."] = f"decoder.up_blocks.{_level}.resnets.{_block}."

# mid blocks exist in both the encoder and the decoder
for _idx in range(2):
    vae_conversion_map[f"mid.block_{_idx + 1}."] = f"mid_block.resnets.{_idx}."


vae_conversion_map_attn = {
    # (original, huggingface)
    "norm.": "group_norm.",
    "q.": "query.",
    "k.": "key.",
    "v.": "value.",
    "proj_out.": "proj_attn.",
}


# This is probably not the most ideal solution, but it does work.
vae_extra_conversion_map = {
    # (original, huggingface)
    "q": "to_q",
    "k": "to_k",
    "v": "to_v",
    "proj_out": "to_out.0",
}
sgm_key = kwargs["key"] 24 | 25 | block_key = "BASE" 26 | if sgm_key.startswith("model.diffusion_model."): 27 | block_key = "OUT11" 28 | if ".time_embed" in sgm_key: 29 | block_key = "BASE" # before input blocks 30 | elif ".out." in sgm_key: 31 | block_key = "OUT11" # after output blocks 32 | elif m := re_inp.search(sgm_key): 33 | block_key = f"IN{int(m.groups(1)[0]):02}" 34 | elif re_mid.search(sgm_key): 35 | block_key = "M00" 36 | elif m := re_out.search(sgm_key): 37 | block_key = f"OUT{int(m.groups(1)[0]):02}" 38 | 39 | return blocks[block_key] 40 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/model_configs/convert_sd1_kohya_to_original.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sd_mecha.extensions.merge_methods import merge_method, StateDict, Parameter, Return 3 | from .convert_huggingface_sd_vae_to_original import convert_vae 4 | from ... import model_configs 5 | 6 | # hf to sd conversion src: 7 | # https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py 8 | 9 | 10 | sd1_kohya = model_configs.resolve("sd1-kohya") 11 | sd1_ldm = model_configs.resolve("sd1-ldm") 12 | 13 | 14 | @merge_method( 15 | identifier=f"convert_'{sd1_kohya.identifier}'_to_'{sd1_ldm.identifier}'", 16 | is_conversion=True, 17 | ) 18 | def convert_sd1_kohya_to_original( 19 | kohya_sd: Parameter(StateDict[torch.Tensor], model_config=sd1_kohya), 20 | **kwargs, 21 | ) -> Return(torch.Tensor, model_config=sd1_ldm): 22 | ldm_key = kwargs["key"] 23 | if ldm_key.startswith("model.diffusion_model."): 24 | return convert_unet(kohya_sd, ldm_key) 25 | elif ldm_key.startswith("cond_stage_model."): 26 | return convert_clip_l(kohya_sd, ldm_key) 27 | elif ldm_key.startswith("first_stage_model."): 28 | return convert_vae(kohya_sd, ldm_key) 29 | else: 30 | return kohya_sd[ldm_key] 31 | 32 | 33 | def convert_unet(kohya_sd: 
unet_conversion_map = {
    # (original: huggingface)
    "time_embed.0.weight": "time_embedding.linear_1.weight",
    "time_embed.0.bias": "time_embedding.linear_1.bias",
    "time_embed.2.weight": "time_embedding.linear_2.weight",
    "time_embed.2.bias": "time_embedding.linear_2.bias",
    "input_blocks.0.0.weight": "conv_in.weight",
    "input_blocks.0.0.bias": "conv_in.bias",
    "out.0.weight": "conv_norm_out.weight",
    "out.0.bias": "conv_norm_out.bias",
    "out.2.weight": "conv_out.weight",
    "out.2.bias": "conv_out.bias",
}


unet_conversion_map_resnet = {
    # (original: huggingface)
    "in_layers.0": "norm1",
    "in_layers.2": "conv1",
    "out_layers.0": "norm2",
    "out_layers.3": "conv2",
    "emb_layers.1": "time_emb_proj",
    "skip_connection": "conv_shortcut",
}


unet_conversion_map_layer = {}
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
# NOTE: convert_unet applies these replacements in insertion order, so the
# construction order below is intentionally identical to the original.
for level in range(4):
    # down path: two resnet/attention pairs per level
    for sub in range(2):
        unet_conversion_map_layer[f"input_blocks.{3 * level + sub + 1}.0."] = f"down_blocks.{level}.resnets.{sub}."
        if level < 3:
            # no attention layers in down_blocks.3
            unet_conversion_map_layer[f"input_blocks.{3 * level + sub + 1}.1."] = f"down_blocks.{level}.attentions.{sub}."

    # up path: three resnet/attention pairs per level
    for sub in range(3):
        unet_conversion_map_layer[f"output_blocks.{3 * level + sub}.0."] = f"up_blocks.{level}.resnets.{sub}."
        if level > 0:
            # no attention layers in up_blocks.0
            unet_conversion_map_layer[f"output_blocks.{3 * level + sub}.1."] = f"up_blocks.{level}.attentions.{sub}."

    if level < 3:
        # no downsample in down_blocks.3
        unet_conversion_map_layer[f"input_blocks.{3 * (level + 1)}.0.op."] = f"down_blocks.{level}.downsamplers.0.conv."

    # the upsampler sits at a different index inside the first output block
    upsample_slot = 1 if level == 0 else 2
    unet_conversion_map_layer[f"output_blocks.{3 * level + 2}.{upsample_slot}."] = f"up_blocks.{level}.upsamplers.0."

unet_conversion_map_layer["middle_block.1."] = "mid_block.attentions.0."
for sub in range(2):
    unet_conversion_map_layer[f"middle_block.{2 * sub}."] = f"mid_block.resnets.{sub}."
import re
from typing import TypeVar
from sd_mecha.extensions.merge_methods import Return, Parameter, StateDict, merge_method


re_inp = re.compile(r"\.input_blocks\.(\d+)\.")  # 12
re_mid = re.compile(r"\.middle_block\.(\d+)\.")  # 1
re_out = re.compile(r"\.output_blocks\.(\d+)\.")  # 12


T = TypeVar("T")


# srcs, in order of priority when disagreements occur:
# - https://github.com/hako-mikan/sd-webui-supermerger/blob/f14b3e5d0be9c510d199cca502c4148160f901bb/scripts/mergers/mergers.py#L1376
# - https://github.com/s1dlx/meh/blob/04af2c8d63744fb6c02d35d328a2c84380cca444/sd_meh/merge.py#L360
# - https://github.com/vladmandic/automatic/blob/e22d0789bddd3894364b0d59a4c9b3e456e89079/modules/merging/merge_utils.py#L64
@merge_method(is_conversion=True)
def convert_sdxl_blocks_to_sgm(
    blocks: Parameter(StateDict[T], model_config="sdxl-supermerger_blocks"),
    **kwargs,
) -> Return(T, model_config="sdxl-sgm"):
    """Resolve which supermerger block weight applies to a given sgm key."""
    sgm_key = kwargs["key"]

    if sgm_key.startswith("first_stage_model."):
        return blocks["VAE"]
    if not sgm_key.startswith("model.diffusion_model."):
        return blocks["BASE"]

    # unet keys: classify by position along the input/middle/output chain
    if ".time_embed" in sgm_key or ".label_emb" in sgm_key:
        return blocks["BASE"]  # before input blocks
    if ".out." in sgm_key:
        return blocks["OUT08"]  # after output blocks
    if match := re_inp.search(sgm_key):
        return blocks[f"IN{int(match.group(1)):02}"]
    if re_mid.search(sgm_key):
        return blocks["M00"]
    if match := re_out.search(sgm_key):
        return blocks[f"OUT{int(match.group(1)):02}"]
    return blocks["OUT08"]  # unet key matching no block pattern
unet_conversion_map = {
    # (original: huggingface)
    "time_embed.0.weight": "time_embedding.linear_1.weight",
    "time_embed.0.bias": "time_embedding.linear_1.bias",
    "time_embed.2.weight": "time_embedding.linear_2.weight",
    "time_embed.2.bias": "time_embedding.linear_2.bias",
    "input_blocks.0.0.weight": "conv_in.weight",
    "input_blocks.0.0.bias": "conv_in.bias",
    "out.0.weight": "conv_norm_out.weight",
    "out.0.bias": "conv_norm_out.bias",
    "out.2.weight": "conv_out.weight",
    "out.2.bias": "conv_out.bias",
    # the following are for sdxl
    "label_emb.0.0.weight": "add_embedding.linear_1.weight",
    "label_emb.0.0.bias": "add_embedding.linear_1.bias",
    "label_emb.0.2.weight": "add_embedding.linear_2.weight",
    "label_emb.0.2.bias": "add_embedding.linear_2.bias",
}

unet_conversion_map_resnet = {
    # (original: huggingface)
    "in_layers.0": "norm1",
    "in_layers.2": "conv1",
    "out_layers.0": "norm2",
    "out_layers.3": "conv2",
    "emb_layers.1": "time_emb_proj",
    "skip_connection": "conv_shortcut",
}


unet_conversion_map_layer = {}
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
# NOTE: replacements are applied in insertion order by the converter, so the
# construction order below is intentionally identical to the original.
for level in range(3):
    # down path: two resnet/attention pairs per level (no attentions at level 0)
    for sub in range(2):
        unet_conversion_map_layer[f"input_blocks.{3 * level + sub + 1}.0."] = f"down_blocks.{level}.resnets.{sub}."
        if level > 0:
            unet_conversion_map_layer[f"input_blocks.{3 * level + sub + 1}.1."] = f"down_blocks.{level}.attentions.{sub}."

    # up path: four resnet/attention pairs per level (no attentions at level 2)
    for sub in range(4):
        unet_conversion_map_layer[f"output_blocks.{3 * level + sub}.0."] = f"up_blocks.{level}.resnets.{sub}."
        if level < 2:
            unet_conversion_map_layer[f"output_blocks.{3 * level + sub}.1."] = f"up_blocks.{level}.attentions.{sub}."

    # always true for range(3); kept from the sd1 template this was based on
    if level < 3:
        unet_conversion_map_layer[f"input_blocks.{3 * (level + 1)}.0.op."] = f"down_blocks.{level}.downsamplers.0.conv."

    # the upsampler sits at a different index inside the first output block
    upsample_slot = 1 if level == 0 else 2
    unet_conversion_map_layer[f"output_blocks.{3 * level + 2}.{upsample_slot}."] = f"up_blocks.{level}.upsamplers.0."
    unet_conversion_map_layer["output_blocks.2.2.conv."] = "output_blocks.2.1.conv."

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer[sd_mid_atn_prefix] = hf_mid_atn_prefix
for sub in range(2):
    unet_conversion_map_layer[f"middle_block.{2 * sub}."] = f"mid_block.resnets.{sub}."
kohya_key.replace(".mlp.c_fc.", ".mlp.fc1.") 38 | kohya_key = kohya_key.replace(".mlp.c_proj.", ".mlp.fc2.") 39 | kohya_key = kohya_key.replace(".ln_final.", ".final_layer_norm.") 40 | kohya_key = kohya_key.replace(".ln_", ".layer_norm") 41 | 42 | if kohya_key.endswith((".in_proj_weight", ".in_proj_bias")): 43 | is_bias = kohya_key.endswith("bias") 44 | partial_key = kohya_key.replace(".in_proj_weight", "").replace(".in_proj_bias", "") 45 | res = torch.vstack([ 46 | kohya_sd[f"{partial_key}.{k}_proj.{'bias' if is_bias else 'weight'}"] 47 | for k in ("q", "k", "v") 48 | ]) 49 | else: 50 | res = kohya_sd[kohya_key] 51 | 52 | if sgm_key.endswith("text_projection"): 53 | res = res.T 54 | 55 | return res 56 | elif sgm_key.startswith("first_stage_model."): 57 | return convert_vae(kohya_sd, sgm_key) 58 | else: 59 | return kohya_sd[sgm_key] 60 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/model_configs/sd1-supermerger_blocks.yaml: -------------------------------------------------------------------------------- 1 | identifier: sd1-supermerger_blocks 2 | components: 3 | blocks: 4 | BASE: {shape: [], dtype: float32} 5 | IN00: {shape: [], dtype: float32} 6 | IN01: {shape: [], dtype: float32} 7 | IN02: {shape: [], dtype: float32} 8 | IN03: {shape: [], dtype: float32} 9 | IN04: {shape: [], dtype: float32} 10 | IN05: {shape: [], dtype: float32} 11 | IN06: {shape: [], dtype: float32} 12 | IN07: {shape: [], dtype: float32} 13 | IN08: {shape: [], dtype: float32} 14 | IN09: {shape: [], dtype: float32} 15 | IN10: {shape: [], dtype: float32} 16 | IN11: {shape: [], dtype: float32} 17 | M00: {shape: [], dtype: float32} 18 | OUT00: {shape: [], dtype: float32} 19 | OUT01: {shape: [], dtype: float32} 20 | OUT02: {shape: [], dtype: float32} 21 | OUT03: {shape: [], dtype: float32} 22 | OUT04: {shape: [], dtype: float32} 23 | OUT05: {shape: [], dtype: float32} 24 | OUT06: {shape: [], dtype: float32} 25 | OUT07: {shape: [], 
dtype: float32} 26 | OUT08: {shape: [], dtype: float32} 27 | OUT09: {shape: [], dtype: float32} 28 | OUT10: {shape: [], dtype: float32} 29 | OUT11: {shape: [], dtype: float32} 30 | -------------------------------------------------------------------------------- /sd_mecha/extensions/builtin/model_configs/sdxl-supermerger_blocks.yaml: -------------------------------------------------------------------------------- 1 | identifier: sdxl-supermerger_blocks 2 | components: 3 | blocks: 4 | BASE: {shape: [], dtype: float32} 5 | IN00: {shape: [], dtype: float32} 6 | IN01: {shape: [], dtype: float32} 7 | IN02: {shape: [], dtype: float32} 8 | IN03: {shape: [], dtype: float32} 9 | IN04: {shape: [], dtype: float32} 10 | IN05: {shape: [], dtype: float32} 11 | IN06: {shape: [], dtype: float32} 12 | IN07: {shape: [], dtype: float32} 13 | IN08: {shape: [], dtype: float32} 14 | M00: {shape: [], dtype: float32} 15 | OUT00: {shape: [], dtype: float32} 16 | OUT01: {shape: [], dtype: float32} 17 | OUT02: {shape: [], dtype: float32} 18 | OUT03: {shape: [], dtype: float32} 19 | OUT04: {shape: [], dtype: float32} 20 | OUT05: {shape: [], dtype: float32} 21 | OUT06: {shape: [], dtype: float32} 22 | OUT07: {shape: [], dtype: float32} 23 | OUT08: {shape: [], dtype: float32} 24 | VAE: {shape: [], dtype: float32} 25 | -------------------------------------------------------------------------------- /sd_mecha/extensions/merge_spaces.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import List, Dict, Set, Tuple 3 | import fuzzywuzzy.process 4 | 5 | 6 | class MergeSpace: 7 | def __init__(self, identifier: str): 8 | self.identifier = identifier 9 | 10 | def __eq__(self, other): 11 | if isinstance(other, str): 12 | other = MergeSpace(other) 13 | return self.identifier == other.identifier 14 | 15 | def __hash__(self): 16 | return hash(self.identifier) 17 | 18 | def __repr__(self): 19 | return f"MergeSpace('{self.identifier}')" 20 
| 21 | 22 | class MergeSpaceSymbol: 23 | def __init__(self, *merge_spaces: Tuple[str | MergeSpace, ...]): 24 | self.merge_spaces = { 25 | MergeSpace(merge_space) if isinstance(merge_space, str) else merge_space 26 | for merge_space in merge_spaces 27 | } 28 | 29 | 30 | AnyMergeSpace = Set[MergeSpace] | MergeSpaceSymbol 31 | 32 | 33 | def register_merge_space(identifier: str): 34 | if identifier in _merge_space_registry: 35 | raise ValueError(f"merge space {identifier} already exists") 36 | _merge_space_registry[identifier] = MergeSpace(identifier) 37 | 38 | 39 | def resolve(identifier: str) -> MergeSpace: 40 | try: 41 | return _merge_space_registry[identifier] 42 | except KeyError as e: 43 | suggestion = fuzzywuzzy.process.extractOne(str(e), _merge_space_registry.keys())[0] 44 | raise ValueError(f"unknown merge space: {e}. Nearest match is '{suggestion}'") 45 | 46 | 47 | def get_identifiers(merge_space: AnyMergeSpace) -> List[str]: 48 | if isinstance(merge_space, Set): 49 | return [merge_space.identifier for merge_space in merge_space] 50 | elif isinstance(merge_space, MergeSpaceSymbol): 51 | return get_identifiers(merge_space.merge_spaces) 52 | else: 53 | raise TypeError(f"expected {MergeSpaceSymbol.__name__} or Tuple[{MergeSpace.__name__}, ...], got {type(merge_space)}") 54 | 55 | 56 | def get_all() -> Set[MergeSpace]: 57 | return set(_merge_space_registry.values()) 58 | 59 | 60 | @functools.cache 61 | def _register_builtin_merge_spaces(): 62 | global _builtin_merge_spaces 63 | for builtin_merge_space in _builtin_merge_spaces: 64 | register_merge_space(builtin_merge_space) 65 | 66 | 67 | _merge_space_registry: Dict[str, MergeSpace] = {} 68 | _builtin_merge_spaces = [ 69 | "weight", 70 | "delta", 71 | "param", 72 | ] 73 | _register_builtin_merge_spaces() 74 | -------------------------------------------------------------------------------- /sd_mecha/extensions/model_configs.py: -------------------------------------------------------------------------------- 1 | 
# a key of a model state dict
StateDictKey = str


@dataclasses.dataclass
class KeyMetadata:
    """Metadata for a single state dict key: its tensor shape and dtype,
    optional alternative key names, and whether the key may be absent.

    YAML-friendly forms are accepted and normalized in `__post_init__`:
    a plain list becomes a `torch.Size` and a dtype name becomes a `torch.dtype`.
    """
    shape: Optional[List[int]] | torch.Size
    dtype: Optional[str] | torch.dtype
    # fix: serialize aliases when present and drop them when empty; the previous
    # predicate `lambda p: bool(p)` excluded exactly the non-empty case, which is
    # inverted relative to the `optional` field's convention below
    aliases: Iterable[str] = dataclasses.field(default_factory=tuple, metadata={"exclude": lambda p: not p})
    optional: bool = dataclasses.field(default=False, metadata={"exclude": lambda p: not p})

    def __post_init__(self):
        # normalize YAML-friendly forms into torch types
        if isinstance(self.shape, list):
            self.shape = torch.Size(self.shape)
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)

    def metadata(self) -> TensorMetadata:
        """Project this entry to the streaming-layer `TensorMetadata`."""
        return TensorMetadata(self.shape, self.dtype)


@dataclasses.dataclass
class ModelComponent:
    """A named group of state dict keys (e.g. "unet", "blocks") of a model config.

    Values given as plain mappings (e.g. parsed YAML) are promoted to
    `KeyMetadata` in `__post_init__`.
    """
    _keys: Mapping[StateDictKey, KeyMetadata] = dataclasses.field(metadata={"serial_name": "keys"})

    def __post_init__(self):
        keys = OrderedDict()
        for k, v in self._keys.items():
            if isinstance(v, Mapping):
                keys[k] = KeyMetadata(**v)
            else:
                keys[k] = v
        self._keys = keys

    def keys(self) -> Mapping[StateDictKey, KeyMetadata]:
        """Return a copy of the key -> metadata mapping."""
        return OrderedDict(
            (k, v)
            for k, v in self._keys.items()
        )

    def metadata(self) -> Mapping[StateDictKey, TensorMetadata]:
        """Return key -> `TensorMetadata` for every key of this component."""
        return OrderedDict(
            (k, v.metadata())
            for k, v in self._keys.items()
        )

    def aliases(self) -> Mapping[StateDictKey, Iterable[StateDictKey]]:
        """Return key -> alternative key names for every key of this component."""
        return OrderedDict(
            (k, v.aliases)
            for k, v in self._keys.items()
        )
70 | @runtime_checkable 71 | class ModelConfig(Protocol): 72 | def eq(self, other): 73 | other_identifier = getattr(other, "identifier", None) 74 | return self.identifier == other_identifier 75 | 76 | def __repr__(self): 77 | return f"" 78 | 79 | @property 80 | @abc.abstractmethod 81 | def identifier(self) -> str: 82 | ... 83 | 84 | @abc.abstractmethod 85 | def get_architecture_identifier(self) -> str: 86 | ... 87 | 88 | @abc.abstractmethod 89 | def get_implementation_identifier(self) -> str: 90 | ... 91 | 92 | @abc.abstractmethod 93 | def components(self) -> Mapping[str, ModelComponent]: 94 | ... 95 | 96 | @abc.abstractmethod 97 | def keys(self) -> Mapping[StateDictKey, KeyMetadata]: 98 | ... 99 | 100 | @abc.abstractmethod 101 | def metadata(self) -> Mapping[StateDictKey, TensorMetadata]: 102 | ... 103 | 104 | @abc.abstractmethod 105 | def aliases(self) -> Mapping[StateDictKey, Iterable[StateDictKey]]: 106 | ... 107 | 108 | 109 | @dataclasses.dataclass 110 | class ModelConfigImpl(ModelConfig): 111 | _identifier: str = dataclasses.field(metadata={"serial_name": "identifier"}) 112 | _components: Mapping[str, ModelComponent] = dataclasses.field(metadata={"serial_name": "components"}) 113 | 114 | _keys_cache: Mapping[StateDictKey, KeyMetadata] = dataclasses.field(default=None, init=False, hash=False, compare=False, metadata={"exclude": True}) 115 | _metadata_cache: Mapping[StateDictKey, TensorMetadata] = dataclasses.field(default=None, init=False, hash=False, compare=False, metadata={"exclude": True}) 116 | _aliases_cache: Mapping[StateDictKey, Iterable[StateDictKey]] = dataclasses.field(default=None, init=False, hash=False, compare=False, metadata={"exclude": True}) 117 | 118 | def __post_init__(self): 119 | if not re.fullmatch("[a-z0-9._+]+-[a-z0-9._+]+", self.identifier): 120 | raise ValueError( 121 | f"Identifier of model {self.identifier} is invalid: " 122 | "it must only contain lowercase alphanumerical characters, '.', '_' or '+', " 123 | "and must match the 
pattern '-'. " 124 | "An example of valid identifier is 'flux_dev-flux'" 125 | ) 126 | 127 | components = OrderedDict() 128 | for k, v in self._components.items(): 129 | if isinstance(v, Mapping): 130 | components[k] = ModelComponent(v) 131 | else: 132 | components[k] = v 133 | self._components = components 134 | 135 | @property 136 | def identifier(self) -> str: 137 | return self._identifier 138 | 139 | def get_architecture_identifier(self): 140 | return self.identifier.split("-")[0] 141 | 142 | def get_implementation_identifier(self): 143 | return self.identifier.split("-")[1] 144 | 145 | def components(self) -> Mapping[str, ModelComponent]: 146 | return self._components 147 | 148 | def keys(self) -> Mapping[StateDictKey, KeyMetadata]: 149 | if self._keys_cache is None: 150 | self._keys_cache = OrderedDict( 151 | (k, v) 152 | for component in self.components().values() 153 | for k, v in component.keys().items() 154 | ) 155 | return self._keys_cache 156 | 157 | def metadata(self) -> Mapping[StateDictKey, TensorMetadata]: 158 | if self._metadata_cache is None: 159 | self._metadata_cache = OrderedDict( 160 | (k, v) 161 | for component in self.components().values() 162 | for k, v in component.metadata().items() 163 | ) 164 | return self._metadata_cache 165 | 166 | def aliases(self) -> Mapping[StateDictKey, Iterable[StateDictKey]]: 167 | if self._aliases_cache is None: 168 | self._aliases_cache = OrderedDict( 169 | (k, v) 170 | for component in self.components().values() 171 | for k, v in component.aliases().items() 172 | ) 173 | return self._aliases_cache 174 | 175 | 176 | def ModelConfigImpl__init__patch(self, *args, **kwargs): 177 | for field in dataclasses.fields(ModelConfigImpl): 178 | if "serial_name" in field.metadata and field.metadata["serial_name"] in kwargs: 179 | kwargs[field.name] = kwargs.pop(field.metadata["serial_name"]) 180 | 181 | ModelConfigImpl__init__(self, *args, **kwargs) 182 | 183 | 184 | ModelConfigImpl__init__ = ModelConfigImpl.__init__ 185 | 
class LazyModelConfigBase(ModelConfig):
    """Base class for configs whose contents are expensive to build.

    Subclasses implement `create_config()`; the result is built on first use
    and cached in `underlying_config`.
    """

    def __init__(self):
        # the real config, built lazily by _ensure_config()
        self.underlying_config = None

    @classmethod
    def __init_subclass__(cls):
        super().__init_subclass__()
        for name, value in inspect.getmembers(ModelConfig):
            if (
                (inspect.isfunction(value) or isinstance(value, property)) and
                getattr(value, "__isabstractmethod__", False) and
                name not in cls.__dict__
            ):
                # implement all remaining abstract methods as delegating to underlying_config
                # (a property is used so that method access resolves lazily per attribute;
                # `name=name` binds the loop variable early to avoid the late-binding closure trap)
                setattr(cls, name, property(lambda self, name=name: resolve_lazy_model_config_attribute(self, name=name)))

    def _ensure_config(self) -> None:
        # build the underlying config exactly once
        if self.underlying_config is not None:
            return

        self.underlying_config = self.create_config()

    @abc.abstractmethod
    def create_config(self) -> ModelConfig:
        raise NotImplementedError


def resolve_lazy_model_config_attribute(self: LazyModelConfigBase, name: str):
    """Fetch `name` from the lazily-built underlying config, materializing it first.

    Bound methods are re-bound to the underlying config so the delegating
    property can return a directly callable method.
    """
    self._ensure_config()
    attribute = getattr(self.underlying_config, name)
    if inspect.ismethod(attribute):
        method = attribute.__func__.__get__(self.underlying_config, self.underlying_config.__class__)
        return method
    return attribute


class StructuralModelConfig(ModelConfig):
    """A config defined purely by a key -> metadata mapping, with no named
    architecture. All keys live in a single "keys" component.
    """

    def __init__(self, keys: Mapping[StateDictKey, KeyMetadata]):
        self._keys_cache = keys
        self._metadata_cache = None
        self._aliases_cache = None

    @property
    def identifier(self) -> str:
        return "structural"

    def get_architecture_identifier(self) -> str:
        return self.identifier

    def get_implementation_identifier(self) -> str:
        # structural configs have no implementation component
        return ""

    def components(self) -> Mapping[str, ModelComponent]:
        # a single component holding every key
        return {"keys": ModelComponent(self.keys())}

    def keys(self) -> Mapping[StateDictKey, KeyMetadata]:
        return self._keys_cache

    def metadata(self) -> Mapping[StateDictKey, TensorMetadata]:
        if self._metadata_cache is None:
            self._metadata_cache = OrderedDict(
                (k, v.metadata())
                for k, v in self.keys().items()
            )
        return self._metadata_cache

    def aliases(self) -> Mapping[StateDictKey, Iterable[StateDictKey]]:
        if self._aliases_cache is None:
            self._aliases_cache = OrderedDict(
                (k, v.aliases)
                for k, v in self.keys().items()
            )
        return self._aliases_cache


class InferModelConfig(ModelConfig):
    """Placeholder config meaning "infer the actual config later".

    Every accessor raises until inference has replaced it; see the INFER
    singleton below.
    """

    def eq(self, other):
        raise RuntimeError("the config has not yet been inferred")

    @property
    def identifier(self) -> str:
        return "infer"

    def get_architecture_identifier(self) -> str:
        return self.identifier

    def get_implementation_identifier(self) -> str:
        return ""

    def components(self) -> Mapping[str, ModelComponent]:
        raise RuntimeError("the config has not yet been inferred")

    def keys(self) -> Mapping[StateDictKey, KeyMetadata]:
        raise RuntimeError("the config has not yet been inferred")

    def metadata(self) -> Mapping[StateDictKey, TensorMetadata]:
        raise RuntimeError("the config has not yet been inferred")

    def aliases(self) -> Mapping[StateDictKey, Iterable[StateDictKey]]:
        raise RuntimeError("the config has not yet been inferred")


# sentinel singleton requesting config inference
INFER = InferModelConfig()


class YamlModelConfig(LazyModelConfigBase):
    """Lazy config loaded from a YAML file; the identifier is the file stem,
    the file is only parsed on first use.
    """

    def __init__(self, yaml_config_file: pathlib.Path):
        super().__init__()
        self.yaml_config_file = yaml_config_file
        # convention: file name is '<architecture>-<implementation>.yaml'
        self._identifier = yaml_config_file.stem

    @property
    def identifier(self) -> str:
        return self._identifier

    def create_config(self) -> ModelConfig:
        with open(self.yaml_config_file, "r") as f:
            yaml_config = f.read()

        return from_yaml(yaml_config)
# primary configs and auxiliary (conversion-only) configs, by identifier
_model_configs_registry_base: Dict[str, ModelConfig] = {}
_model_configs_registry_aux: Dict[str, ModelConfig] = {}


def serialize(obj):
    """Recursively convert a model config object graph into plain
    YAML-friendly data (dicts, lists, strings, numbers).

    Raises:
        ValueError: for objects with no serialized form.
    """
    if isinstance(obj, ModelComponent):
        # fix: call keys() -- previously the bound method itself was passed,
        # which always raised "Cannot serialize". The component's key metadata
        # nests directly under the component name (matching the YAML layout).
        return serialize(obj.keys())
    elif dataclasses.is_dataclass(obj):
        result = {}
        for field in dataclasses.fields(obj):
            exclude = field.metadata.get("exclude", False)
            value = getattr(obj, field.name)
            # fix: when "exclude" is a predicate, use its result for this value;
            # previously the callable's own truthiness short-circuited the
            # condition, so predicate-guarded fields were always excluded
            # (e.g. `optional: true` was silently dropped from the output)
            excluded = exclude(value) if callable(exclude) else bool(exclude)
            if not excluded:
                result[field.metadata.get("serial_name", field.name)] = serialize(value)
        return result
    elif isinstance(obj, (str, int, float, type(None))):
        return obj
    elif isinstance(obj, torch.dtype):
        # "torch.float32" -> "float32"
        return str(obj).split(".")[1]
    elif isinstance(obj, torch.Size):
        return list(obj)
    elif isinstance(obj, Mapping):
        return {str(k): serialize(v) for k, v in obj.items()}
    elif isinstance(obj, Iterable) and not isinstance(obj, bytes):
        return [serialize(v) for v in obj]
    else:
        raise ValueError(f"Cannot serialize {obj!r}")


def to_yaml(model_config: ModelConfig) -> str:
    """Serialize a model config to YAML.

    Lists and shape/dtype mappings are emitted in flow style so each key's
    metadata stays on a single line. The SafeDumper representers are mutated
    temporarily and restored in the `finally` block.
    """
    dict_config = serialize(model_config)
    old_representers = yaml.SafeDumper.yaml_representers.copy()

    def flow_style_list_representer(dumper, data):
        return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)

    def flow_style_dict_representer(dumper, data):
        # key metadata dicts ({shape: ..., dtype: ...}) stay on one line
        if "shape" in data and "dtype" in data:
            return dumper.represent_mapping('tag:yaml.org,2002:map', data, flow_style=True)
        else:
            return dumper.represent_mapping('tag:yaml.org,2002:map', data)

    try:
        yaml.SafeDumper.add_representer(list, flow_style_list_representer)
        yaml.SafeDumper.add_representer(dict, flow_style_dict_representer)
        return yaml.safe_dump(dict_config, width=2**64, sort_keys=False)
    finally:
        yaml.SafeDumper.yaml_representers.clear()
        yaml.SafeDumper.yaml_representers.update(old_representers)


def from_yaml(yaml_config: str) -> ModelConfig:
    """Deserialize a model config from YAML text.

    NOTE(review): this uses the full (C)Loader, which can construct arbitrary
    python objects -- only load trusted config files, or consider SafeLoader.
    """
    dict_config = yaml.load(yaml_config, Loader=YamlLoader)
    return ModelConfigImpl(**dict_config)


def register(config: ModelConfig):
    """Register a primary model config.

    Raises:
        ValueError: if the identifier is already registered (base or aux).
    """
    if config.identifier in _model_configs_registry_base or config.identifier in _model_configs_registry_aux:
        raise ValueError(f"Model {config.identifier} already exists")

    _model_configs_registry_base[config.identifier] = config


def register_aux(config: ModelConfig):
    """Register an auxiliary model config (same uniqueness rules as `register`)."""
    if config.identifier in _model_configs_registry_base or config.identifier in _model_configs_registry_aux:
        raise ValueError(f"Model {config.identifier} already exists")

    _model_configs_registry_aux[config.identifier] = config
def resolve(identifier: str) -> ModelConfig:
    """Look up a registered model config by identifier.

    Resolution order: base registry, then aux registry, then the INFER
    sentinel. "structural" is rejected explicitly since structural configs
    must be instantiated per state dict.

    Raises:
        ValueError: if unknown, with a fuzzy-matched suggestion when possible.
    """
    try:
        return _model_configs_registry_base[identifier]
    except KeyError:
        pass

    try:
        return _model_configs_registry_aux[identifier]
    except KeyError:
        pass

    if identifier == INFER.identifier:
        return INFER

    if identifier == "structural":
        raise ValueError(
            "the 'structural' model config is not a unique object, "
            "it needs to be manually instantiated in this way: model_configs.StructuralModelConfig(...)"
        )

    # suggest the closest known identifier (base registry only)
    suggestions = fuzzywuzzy.process.extractOne(identifier, _model_configs_registry_base.keys())
    postfix = ""
    if suggestions is not None:
        postfix = f". Nearest match is '{suggestions[0]}'"
    raise ValueError(f"unknown model implementation: {identifier}{postfix}")


def get_all() -> List[ModelConfig]:
    """Return all registered configs (base + aux), sorted by identifier descending."""
    res = get_all_base() + get_all_aux()
    res.sort(key=lambda c: c.identifier, reverse=True)
    return res


def get_all_base() -> List[ModelConfig]:
    """Return all primary model configs."""
    return list(_model_configs_registry_base.values())


def get_all_aux() -> List[ModelConfig]:
    """Return all auxiliary model configs."""
    return list(_model_configs_registry_aux.values())


# --- sd_mecha/extensions/model_formats.py ---


@runtime_checkable
class ModelFormat(Protocol):
    """Protocol for on-disk model formats: path detection plus read/write
    state dict access.
    """

    # set by the registration machinery
    identifier: str

    @abc.abstractmethod
    def matches(self, path: pathlib.Path) -> bool:
        ...

    @abc.abstractmethod
    def get_read_dict(self, path: pathlib.Path, buffer_size: int) -> Mapping[str, torch.Tensor]:
        ...

    @abc.abstractmethod
    def get_write_dict(
        self,
        path: pathlib.Path,
        model_config,
        mecha_recipe: str,
        buffer_size: int,
    ) -> WriteOnlyMapping[str, torch.Tensor]:
        ...
def register_model_format(
    model_format: Optional[type(ModelFormat)] = None, *,
    identifier: Optional[str] = None,
):
    """Register a model format class; usable as a decorator with or without
    arguments: `@register_model_format` or `@register_model_format(identifier=...)`.
    """
    if model_format is None:
        # called with arguments: return the actual decorator
        return lambda model_format: __register_model_format_impl(model_format, identifier=identifier)
    return __register_model_format_impl(model_format, identifier=identifier)


def __register_model_format_impl(
    model_format: type(ModelFormat), *,
    identifier: Optional[str],
):
    if not inspect.isclass(model_format):
        # fix: report the type of the value actually passed; previously this
        # printed `type(ModelFormat)` (the protocol class) instead
        raise ValueError(f"model_format must be a class, not {type(model_format)}")

    if identifier is None:
        identifier = model_format.__name__

    if identifier in _model_format_registry:
        raise ValueError(f"model format '{identifier}' already exists")

    # a single shared instance is registered, tagged with its identifier
    model_format = model_format()
    model_format.identifier = identifier
    _model_format_registry[identifier] = model_format


def resolve(identifier: str) -> ModelFormat:
    """Look up a registered model format by identifier.

    Raises:
        ValueError: if unknown, with a fuzzy-matched suggestion.
    """
    try:
        return _model_format_registry[identifier]
    except KeyError:
        # fix: fuzzy-match the raw identifier; str(KeyError) wraps the key in
        # quotes, which degraded both the suggestion and the error message
        suggestion = fuzzywuzzy.process.extractOne(identifier, _model_format_registry.keys())[0]
        raise ValueError(f"unknown model format: {identifier}. Nearest match is '{suggestion}'")


def get_all() -> List[ModelFormat]:
    """Return all registered model formats."""
    return list(_model_format_registry.values())


_model_format_registry = {}


@functools.cache
def _register_builtin_model_formats():
    # cached so re-imports cannot double-register the builtin format

    @register_model_format(identifier="single_file")
    class SingleFileModelFormat(ModelFormat):
        """Single `.safetensors` file on disk."""

        def matches(self, path: pathlib.Path) -> bool:
            path = path.resolve()
            return path.suffix == ".safetensors"

        def get_read_dict(self, path: pathlib.Path, buffer_size: int) -> Mapping[str, torch.Tensor]:
            return InSafetensorsDict(path, buffer_size)

        def get_write_dict(
            self,
            path: pathlib.Path,
            model_config,
            mecha_recipe: str,
            buffer_size: int,
        ) -> WriteOnlyMapping[str, torch.Tensor]:
            return OutSafetensorsDict(path, model_config.metadata, mecha_recipe, buffer_size)


_register_builtin_model_formats()
def model(
    state_dict: str | pathlib.Path | Mapping[str, torch.Tensor],
    config: Optional[str | ModelConfig] = None,
    merge_space: str | MergeSpace = "weight",
) -> RecipeNode:
    """
    Create a recipe node representing a state dict.

    Args:
        state_dict:
            Path to a `.safetensors` file or an already loaded state dict.
        config:
            Model config or an identifier thereof.
        merge_space:
            The merge space in which the model is expected to be.

    Returns:
        ModelRecipeNode: A node that can be used in recipe graphs.
    """
    if merge_space is None:
        raise ValueError("merge space cannot be None")

    if isinstance(state_dict, Mapping):
        # only the first value is inspected; a full scan would be wasteful
        for first_value in state_dict.values():
            if not isinstance(first_value, torch.Tensor):
                raise ValueError(f"state dict must contain values of type Tensor, not {type(first_value)}")
            break
        return LiteralRecipeNode(state_dict, model_config=config, merge_space=merge_space)

    path = pathlib.Path(state_dict) if isinstance(state_dict, str) else state_dict
    return ModelRecipeNode(path, model_config=config, merge_space=merge_space)


def literal(
    value: NonDictLiteralValue | dict,
    config: Optional[ModelConfig | str] = None,
    merge_space: Optional[str | MergeSpace] = None,
) -> LiteralRecipeNode:
    """
    Create a recipe node wrapping tensors or some builtin python objects.

    This gives direct access to recipe node properties on python objects, i.e.:
    ```python
    sd_mecha.literal({...}) | 3
    ```
    Wrapping inputs manually is typically unnecessary: this function is applied
    implicitly whenever needed.

    Args:
        value:
            The literal value to wrap into a recipe node.
        config:
            Model config or an identifier thereof.
        merge_space:
            The merge space in which the literal is expected to be.

    Returns:
        LiteralRecipeNode: A recipe node representing the literal value.
    """
    node = LiteralRecipeNode(value, model_config=config, merge_space=merge_space)
    return node
68 | """ 69 | return LiteralRecipeNode(value, model_config=config, merge_space=merge_space) 70 | 71 | 72 | def set_log_level(level: str = "INFO"): 73 | logging.basicConfig(format="%(levelname)s: %(message)s", level=level) 74 | 75 | 76 | class Defaults: 77 | """ 78 | Convenience class for common recipe operations to reduce repetition in recipe scripts. 79 | """ 80 | 81 | def __init__( 82 | self, 83 | model_dirs: pathlib.Path | str | Iterable[pathlib.Path | str] = ..., 84 | merge_device: Optional[str | torch.device] = ..., 85 | merge_dtype: Optional[torch.dtype] = ..., 86 | output_device: Optional[str | torch.device] = ..., 87 | output_dtype: Optional[torch.dtype] = ..., 88 | threads: Optional[int] = ..., 89 | total_buffer_size: int = ..., 90 | strict_weight_space: bool = ..., 91 | check_finite: bool = ..., 92 | omit_extra_keys: bool = ..., 93 | omit_ema: bool = ..., 94 | check_mandatory_keys: bool = ..., 95 | tqdm: type = ..., 96 | ): 97 | """ 98 | Args: 99 | See documentation for `sd_mecha.merge` or `sd_mecha.conversion` for a description of each parameter. 100 | """ 101 | self.__model_dirs = model_dirs 102 | self.__merge_device = merge_device 103 | self.__merge_dtype = merge_dtype 104 | self.__output_device = output_device 105 | self.__output_dtype = output_dtype 106 | self.__threads = threads 107 | self.__total_buffer_size = total_buffer_size 108 | self.__strict_weight_space = strict_weight_space 109 | self.__check_finite = check_finite 110 | self.__omit_extra_keys = omit_extra_keys 111 | self.__omit_ema = omit_ema 112 | self.__check_mandatory_keys = check_mandatory_keys 113 | self.__tqdm = tqdm 114 | 115 | def convert( 116 | self, 117 | recipe: RecipeNodeOrValue, 118 | config: str | ModelConfig | RecipeNode, 119 | model_dirs: pathlib.Path | str | Iterable[pathlib.Path | str] = ..., 120 | ): 121 | """ 122 | Convert a recipe or model from one model config to another. 123 | 124 | See `sd_mecha.convert` for more information. 
125 | """ 126 | if model_dirs is ...: 127 | model_dirs = self.__model_dirs 128 | model_dirs = self._model_dirs_to_pathlib_list(model_dirs) 129 | return convert(recipe, config, model_dirs) 130 | 131 | def merge( 132 | self, 133 | recipe: RecipeNodeOrValue, 134 | *, 135 | fallback_model: Optional[RecipeNodeOrValue] = ..., 136 | merge_device: Optional[str | torch.device] = ..., 137 | merge_dtype: Optional[torch.dtype] = ..., 138 | output_device: Optional[str | torch.device] = ..., 139 | output_dtype: Optional[torch.dtype] = ..., 140 | threads: Optional[int] = ..., 141 | total_buffer_size: int = ..., 142 | model_dirs: pathlib.Path | str | Iterable[pathlib.Path | str] = ..., 143 | strict_weight_space: bool = ..., 144 | check_finite: bool = ..., 145 | omit_extra_keys: bool = ..., 146 | omit_ema: bool = ..., 147 | check_mandatory_keys: bool = ..., 148 | tqdm: type = ..., 149 | output: MutableMapping[str, torch.Tensor] | pathlib.Path | str = ..., 150 | ) -> Optional[MutableMapping[str, torch.Tensor]]: 151 | """ 152 | Materialize a state dict from a recipe graph and optionally save it to a file. 153 | 154 | See `sd_mecha.merge` for more information. 
155 | """ 156 | if merge_device is ...: 157 | merge_device = self.__merge_device 158 | if merge_dtype is ...: 159 | merge_dtype = self.__merge_dtype 160 | if output_device is ...: 161 | output_device = self.__output_device 162 | if output_dtype is ...: 163 | output_dtype = self.__output_dtype 164 | if threads is ...: 165 | threads = self.__threads 166 | if total_buffer_size is ...: 167 | total_buffer_size = self.__total_buffer_size 168 | if model_dirs is ...: 169 | model_dirs = self.__model_dirs 170 | model_dirs = self._model_dirs_to_pathlib_list(model_dirs) 171 | if strict_weight_space is ...: 172 | strict_weight_space = self.__strict_weight_space 173 | if check_finite is ...: 174 | check_finite = self.__check_finite 175 | if omit_extra_keys is ...: 176 | omit_extra_keys = self.__omit_extra_keys 177 | if omit_ema is ...: 178 | omit_ema = self.__omit_ema 179 | if check_mandatory_keys is ...: 180 | check_mandatory_keys = self.__check_mandatory_keys 181 | 182 | return merge( 183 | recipe, 184 | fallback_model=fallback_model, 185 | merge_device=merge_device, 186 | merge_dtype=merge_dtype, 187 | output_device=output_device, 188 | output_dtype=output_dtype, 189 | threads=threads, 190 | total_buffer_size=total_buffer_size, 191 | model_dirs=model_dirs, 192 | strict_weight_space=strict_weight_space, 193 | check_finite=check_finite, 194 | omit_extra_keys=omit_extra_keys, 195 | omit_ema=omit_ema, 196 | check_mandatory_keys=check_mandatory_keys, 197 | tqdm=tqdm, 198 | output=output, 199 | ) 200 | 201 | @staticmethod 202 | def _model_dirs_to_pathlib_list(models_dir): 203 | if models_dir is ...: 204 | models_dir = [] 205 | if not isinstance(models_dir, List): 206 | models_dir = [models_dir] 207 | models_dir = list(models_dir) 208 | for i in range(len(models_dir)): 209 | if isinstance(models_dir[i], str): 210 | models_dir[i] = pathlib.Path(models_dir[i]) 211 | if models_dir[i] is not None: 212 | models_dir[i] = models_dir[i].absolute() 213 | 214 | return models_dir 215 | 
def add_difference(
    a: RecipeNodeOrValue, b: RecipeNodeOrValue, c: Optional[RecipeNodeOrValue] = None, *,
    alpha: Parameter(Tensor) = 1.0,
    clamp_to_ab: Parameter(bool) = False,
) -> recipe_nodes.RecipeNode:
    """Add `alpha` times a difference to `a`.

    When `c` is given, the difference is `b - c`; otherwise `b` is used
    directly. With `clamp_to_ab`, the result is clamped between `a` and `b`,
    which requires both to live in the same merge space.
    """
    a = value_to_node(a)
    b = value_to_node(b)
    original_b = b

    if c is not None:
        b = subtract(b, value_to_node(c))

    res = merge_methods.add_difference(a, b, alpha=alpha)

    # when the raw b shares a's merge space, prefer it as the clamp bound
    if a.merge_space == original_b.merge_space:
        b = original_b

    if clamp_to_ab:
        if a.merge_space != b.merge_space:
            raise TypeError(f"Merge space of A {a.merge_space} and B {b.merge_space} must be the same to clamp the merge.")
        res = clamp(res, a, b)

    return res | a


def add_perpendicular(
    a: RecipeNodeOrValue, b: RecipeNodeOrValue, c: RecipeNodeOrValue, *,
    alpha: Parameter(Tensor) = 1.0,
) -> recipe_nodes.RecipeNode:
    """Add to `a` the component of `b - c` perpendicular to `a - c`, scaled by `alpha`."""
    a = value_to_node(a)
    b = value_to_node(b)
    c = value_to_node(c)

    delta_a = subtract(a, c)
    delta_b = subtract(b, c)
    perpendicular = merge_methods.perpendicular_component(delta_a, delta_b)

    return merge_methods.add_difference(a, perpendicular, alpha=alpha)
# latex notes in reference to original implementation: https://arxiv.org/abs/2306.01708
# - `base`: $$ \theta_{init} $$
# - `*models`: $$ \{\theta_{init}\}_{t=1}^n $$
# - `models` after `subtract`: $$ \tau_t $$
# - `alpha`: $$ \lambda $$
# - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ )
# - `res`: $$ \lambda * \tau_m $$
# - `return`: $$ \theta_m $$
# Special mode "TIES-SOUP" has been implemented by setting `vote_sgn` > 0.0
# Special mode "TIES-STOCK" has been implemented by setting `apply_stock` > 0.0
def add_difference_ties(
    base: RecipeNodeOrValue,
    *models: RecipeNodeOrValue,
    alpha: Parameter(Tensor) = 1.0,
    k: Parameter(float) = 1.0,
) -> recipe_nodes.RecipeNode:
    """TIES-merge `models` into `base`: build task vectors, TIES-sum them,
    then add the result to `base` scaled by `alpha`.
    """
    # $$ \{\theta_{init}\}_{t=1}^n $$
    base = value_to_node(base)
    models = tuple(value_to_node(model) for model in models)

    # Create task vectors.
    # $$ \tau_t $$
    # models already in delta space are used as-is
    models = tuple(
        subtract(model, base)
        if model.merge_space == "weight" else
        model
        for model in models
    )

    # step 1 + step 2 + step 3
    res = ties_sum(
        *models,
        k=k,
    )

    # Obtain merged checkpoint

    # $$ \theta_{init} + \lambda * \tau_m $$
    return add_difference(
        base, res,
        alpha=alpha,
    )


def add_difference_ties_extended(
    base: RecipeNodeOrValue,
    *models: RecipeNodeOrValue,
    alpha: Parameter(Tensor) = 1.0,
    k: Parameter(float) = 0.2,
    vote_sgn: Parameter(bool) = False,
    apply_stock: Parameter(bool) = False,
    cos_eps: Parameter(float) = 1e-6,
    apply_median: Parameter(bool) = False,
    eps: Parameter(float) = 1e-6,
    maxiter: Parameter(int) = 100,
    ftol: Parameter(float) = 1e-20,
) -> recipe_nodes.RecipeNode:
    """Extended TIES merge with optional sign voting (TIES-SOUP), model stock
    weighting (TIES-STOCK) and geometric median aggregation.
    """
    # $$ \{\theta_{init}\}_{t=1}^n $$
    base = value_to_node(base)
    models = tuple(value_to_node(model) for model in models)

    # Create task vectors.
    # $$ \tau_t $$
    models = tuple(
        subtract(model, base)
        if model.merge_space == "weight" else
        model
        for model in models
    )

    # step 1 + step 2 + step 3
    res = ties_sum_extended(
        *models,
        k=k,
        vote_sgn=vote_sgn,
        apply_stock=apply_stock,
        cos_eps=cos_eps,
        apply_median=apply_median,
        eps=eps,
        maxiter=maxiter,
        ftol=ftol,
    )

    # Obtain merged checkpoint

    # $$ \theta_{init} + \lambda * \tau_m $$
    return add_difference(
        base, res,
        alpha=alpha,
    )


def copy_region(
    a: RecipeNodeOrValue, b: RecipeNodeOrValue, c: Optional[RecipeNodeOrValue] = None, *,
    width: Parameter(float) = 1.0,
    offset: Parameter(float) = 0.0,
    top_k: Parameter(bool) = False,
) -> recipe_nodes.RecipeNode:
    """Copy a contiguous region of parameters from `b` into `a`.

    When `c` is given, both inputs are first taken relative to `c` and `c` is
    added back afterwards. `top_k` switches to the top-k variant of the copy.
    """
    a = value_to_node(a)
    b = value_to_node(b)

    if c is not None:
        c = value_to_node(c)

        # work relative to c
        a = subtract(
            a, c,
        )
        b = subtract(
            b, c,
        )

    # dispatch on top_k: False -> tensor_sum, True -> top_k_tensor_sum
    copy_method = [merge_methods.tensor_sum, merge_methods.top_k_tensor_sum][int(top_k)]
    res = copy_method(
        a, b,
        width=width,
        offset=offset,
    )

    if c is not None:
        # add c back to return to the original space
        res = merge_methods.add_difference(
            c, res,
            alpha=1.0,
        )

    return res


# backwards-compatible alias
tensor_sum = copy_region


def rotate(
    a: RecipeNodeOrValue, b: RecipeNodeOrValue, c: Optional[RecipeNodeOrValue] = None, *,
    alignment: Parameter(float) = 1.0,
    alpha: Parameter(float) = 0.0,
    cache: Optional[Dict[str, torch.Tensor]] = None,
) -> recipe_nodes.RecipeNode:
    """Rotate `a` towards `b` (optionally relative to `c`).

    `cache` is forwarded to the underlying merge method to reuse expensive
    intermediate results across keys.
    """
    a = value_to_node(a)
    b = value_to_node(b)

    if c is not None:
        c = value_to_node(c)

        # work relative to c
        a = subtract(
            a, c,
        )
        b = subtract(
            b, c,
        )

    res = merge_methods.rotate(
        a, b,
        alignment=alignment,
        alpha=alpha,
    ).set_cache(cache)

    if c is not None:
        # add c back to return to the original space
        res = merge_methods.add_difference(
            c, res,
            alpha=1.0,
        )

    return res


def dropout(
    a: RecipeNodeOrValue,
    *models: RecipeNodeOrValue,
    probability: Parameter(float) = 0.9,
    alpha: Parameter(Tensor) = 0.5,
    overlap: Parameter(float) = 1.0,
    overlap_emphasis: Parameter(float) = 0.0,
    seed: Parameter(int) = None,
) -> recipe_nodes.RecipeNode:
    """Binomial-dropout merge: randomly drop delta elements of each model
    relative to `a`, then add the combined delta back scaled by `alpha`.
    """
    deltas = [
        subtract(model, a)
        for model in models
    ]
    ba_delta = merge_methods.dropout(*deltas, probability=probability, overlap=overlap, overlap_skew=overlap_emphasis, seed=seed)
    return merge_methods.add_difference(a, ba_delta, alpha=alpha)


# re-exported builtin merge method
ties_sum_with_dropout = merge_methods.ties_sum_with_dropout


# latex notes in reference to original implementation: https://arxiv.org/abs/2311.03099
# Notice that this is "TIES Merging w/ DARE", which is "Prune > Merge (TIES) > Rescale"
# See https://slgero.medium.com/merge-large-language-models-29897aeb1d1a for details
# - `base`: $$ \theta_{PRE} $$
# - `*models`: $$ \theta_{SFT}^{t_k} $$
# - `deltas`: $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
# - `probability`: $$ p $$
# - `res`: $$ \hat{\delta}^t = \tilde{\delta}^t / (1-p) $$
# - `alpha`: $$ \lambda $$
# - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ ) in TIES paper
# - `return`: $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
# Special mode "TIES-SOUP" has been implemented by setting `vote_sgn` > 0.0
def ties_with_dare(
    base: RecipeNodeOrValue,
    *models: RecipeNodeOrValue,
    probability: Parameter(float) = 0.9,
    rescale: Parameter(bool) = True,
    alpha: Parameter(Tensor) = 1.0,
    seed: Parameter(int) = None,
    k: Parameter(float) = 1.0,
    vote_sgn: Parameter(bool) = False,
    apply_stock: Parameter(bool) = False,
    cos_eps: Parameter(float) = 1e-6,
    apply_median: Parameter(bool) = False,
    eps: Parameter(float) = 1e-6,
    maxiter: Parameter(int) = 100,
    ftol: Parameter(float) = 1e-20,
) -> recipe_nodes.RecipeNode:
    """TIES merging with DARE pruning: prune deltas with probability
    `probability`, TIES-sum the survivors (optionally rescaled), and add the
    result to `base` scaled by `alpha`.
    """
    # $$ \delta^t = \theta_{SFT}^{t} - \theta_{PRE} \in \mathbb{R}^d $$
    base = value_to_node(base)
    models = tuple(value_to_node(model) for model in models)
    # models already in delta space are used as-is
    deltas = tuple(
        subtract(model, base)
        if model.merge_space == "weight" else
        model
        for model in models
    )

    # $$ \tilde{\delta}^{t_k} $$
    res = ties_sum_with_dropout(
        *deltas,
        probability=probability,
        rescale=rescale,
        k=k,
        vote_sgn=vote_sgn,
        seed=seed,
        apply_stock=apply_stock,
        cos_eps=cos_eps,
        apply_median=apply_median,
        eps=eps,
        maxiter=maxiter,
        ftol=ftol,
    )

    # $$ \theta_M = \theta_{PRE} + \lambda \cdot \Sigma_{k=1}^{K} \tilde{\delta}^{t_k} $$
    return merge_methods.add_difference(base, res, alpha=alpha)


# Following mergekit's implementation of Model Stock (which official implementation doesn't exist)
# https://github.com/arcee-ai/mergekit/blob/main/mergekit/merge_methods/model_stock.py
def n_model_stock(
    base: RecipeNodeOrValue,
    *models: RecipeNodeOrValue,
    cos_eps: Parameter(float) = 1e-6,
) -> recipe_nodes.RecipeNode:
    """N-model Model Stock merge: weight deltas by pairwise angular agreement,
    then add the combined delta back to `base`.
    """
    base = value_to_node(base)
    models = tuple(value_to_node(model) for model in models)
    # models already in delta space are used as-is
    deltas = tuple(
        subtract(model, base)
        if model.merge_space == "weight" else
        model
        for model in models
    )

    # This is hacky: Both w_avg and w_h will be calculated there.
    # Notice that t and cos_theta is vector instead of single value.
    # Conceptually it could be compatible with TIES, but algorithm should be rewritten.
    res = model_stock(
        *deltas,
        cos_eps=cos_eps,
    )

    return merge_methods.add_difference(base, res, alpha=1.0)
merge_methods.resolve("fallback")(self, other) 54 | 55 | def __ror__(self, other: "RecipeNodeOrValue") -> "RecipeNode": 56 | other = merge_methods.value_to_node(other) 57 | return other | self 58 | 59 | def to(self, *, device: Optional[Union[str, torch.device, "RecipeNode"]] = None, dtype: Optional[Union[str, torch.dtype, "RecipeNode"]] = None): 60 | if isinstance(device, torch.device): 61 | device = str(device) 62 | if isinstance(dtype, torch.dtype): 63 | from sd_mecha.extensions.builtin.merge_methods import cast_dtype_map_reversed 64 | dtype = cast_dtype_map_reversed[dtype] 65 | return merge_methods.resolve("cast")(self, device=device, dtype=dtype) 66 | 67 | 68 | PythonLiteralValue = str | int | float | bool | type(None) 69 | NonDictLiteralValue = PythonLiteralValue | torch.Tensor 70 | LiteralValue = NonDictLiteralValue | dict 71 | RecipeNodeOrValue = RecipeNode | LiteralValue | pathlib.Path 72 | 73 | 74 | class LiteralRecipeNode(RecipeNode): 75 | def __init__( 76 | self, 77 | value: LiteralValue, 78 | *, 79 | model_config: Optional[str | model_configs.ModelConfig] = None, 80 | merge_space: Optional[str | merge_spaces.MergeSpace] = None, 81 | ): 82 | self.value = value 83 | self.__model_config = model_configs.resolve(model_config) if isinstance(model_config, str) else model_config 84 | self.__merge_space = merge_spaces.resolve(merge_space) if isinstance(merge_space, str) else merge_space 85 | if isinstance(self.value, dict): 86 | first_value = next(iter((*self.value.values(), Ellipsis))) # using ... 
as placeholder as None is a valid LiteralValue 87 | if isinstance(first_value, RecipeNode): 88 | if model_config is None: 89 | self.__model_config = first_value.model_config 90 | if merge_space is None: 91 | self.__merge_space = first_value.merge_space 92 | if not all(v.model_config == self.__model_config for v in self.value.values()): 93 | raise ValueError(f"All model configs should be the same, expected {self.__model_config} but got {set(v.model_config for v in self.value.values())}") 94 | if not all(v.merge_space == first_value.merge_space for v in self.value.values()): 95 | raise ValueError(f"All merge spaces should be the same, but got multiple: {set(v.merge_space for v in self.value.values())}") 96 | if self.__merge_space is None: 97 | self.__merge_space = merge_spaces.resolve("param") 98 | 99 | def accept(self, visitor, *args, **kwargs): 100 | return visitor.visit_literal(self, *args, **kwargs) 101 | 102 | @property 103 | def merge_space(self) -> MergeSpace: 104 | return self.__merge_space 105 | 106 | @property 107 | def model_config(self) -> Optional[model_configs.ModelConfig]: 108 | return self.__model_config 109 | 110 | @model_config.setter 111 | def model_config(self, model_config: Optional[model_configs.ModelConfig]): 112 | self.__model_config = model_config 113 | 114 | def __contains__(self, item): 115 | if isinstance(item, LiteralRecipeNode): 116 | return self.value == item.value 117 | else: 118 | return False 119 | 120 | 121 | class ModelRecipeNode(RecipeNode): 122 | def __init__( 123 | self, 124 | path: pathlib.Path, 125 | *, 126 | model_config: Optional[str | model_configs.ModelConfig] = None, 127 | merge_space: str | MergeSpace = "weight", 128 | ): 129 | if not isinstance(path, pathlib.Path): 130 | raise TypeError(f"The type of 'state_dict' must be Path, not {type(path).__name__}") 131 | 132 | self.path = path 133 | self.state_dict = None 134 | self.__model_config = model_configs.resolve(model_config) if isinstance(model_config, str) else 
model_config 135 | self.__merge_space = merge_spaces.resolve(merge_space) if isinstance(merge_space, str) else merge_space 136 | 137 | def accept(self, visitor, *args, **kwargs): 138 | return visitor.visit_model(self, *args, **kwargs) 139 | 140 | @property 141 | def merge_space(self) -> merge_spaces.MergeSpace: 142 | return self.__merge_space 143 | 144 | @property 145 | def model_config(self) -> Optional[model_configs.ModelConfig]: 146 | return self.__model_config 147 | 148 | @model_config.setter 149 | def model_config(self, value: Optional[model_configs.ModelConfig]): 150 | self.__model_config = value 151 | 152 | def __contains__(self, item): 153 | if isinstance(item, ModelRecipeNode): 154 | return self.path == item.path 155 | else: 156 | return False 157 | 158 | 159 | class MergeRecipeNode(RecipeNode): 160 | def __init__( 161 | self, 162 | merge_method, 163 | args: Tuple[RecipeNode, ...], 164 | kwargs: Dict[str, RecipeNode], 165 | cache: dict = None, 166 | ): 167 | self.merge_method = merge_method 168 | self.args = args 169 | self.kwargs = kwargs 170 | self.cache = cache 171 | self.__validate_args() 172 | 173 | def __validate_args(self): 174 | if not isinstance(self.merge_space, MergeSpace): 175 | raise RuntimeError(f"Could not infer merge space from arguments for method {self.merge_method.identifier}") 176 | 177 | def accept(self, visitor, *args, **kwargs): 178 | return visitor.visit_merge(self, *args, **kwargs) 179 | 180 | @property 181 | def merge_space(self) -> merge_spaces.MergeSpace: 182 | return self.merge_method.get_return_merge_space( 183 | [v.merge_space for v in self.args], 184 | {k: v.merge_space for k, v in self.kwargs.items()}, 185 | ) 186 | 187 | @property 188 | def model_config(self) -> Optional[model_configs.ModelConfig]: 189 | return self.merge_method.get_return_config( 190 | [v.model_config for v in self.args], 191 | {k: v.model_config for k, v in self.kwargs.items()}, 192 | ) 193 | 194 | def __contains__(self, item): 195 | return self is item 
or any( 196 | item in v 197 | for v in itertools.chain(self.args, self.kwargs.values()) 198 | if isinstance(v, RecipeNode) 199 | ) 200 | 201 | def set_cache(self, cache: dict = ...): 202 | if cache is Ellipsis: 203 | cache = {} 204 | 205 | self.cache = cache 206 | return self 207 | 208 | 209 | class RecipeVisitor(abc.ABC): 210 | @abc.abstractmethod 211 | def visit_literal(self, node: LiteralRecipeNode): 212 | pass 213 | 214 | @abc.abstractmethod 215 | def visit_model(self, node: ModelRecipeNode): 216 | pass 217 | 218 | @abc.abstractmethod 219 | def visit_merge(self, node: MergeRecipeNode): 220 | pass 221 | 222 | 223 | class ModelDepthRecipeVisitor(RecipeVisitor): 224 | def visit_literal(self, node: LiteralRecipeNode): 225 | return 0 226 | 227 | def visit_model(self, _node: ModelRecipeNode): 228 | return 1 229 | 230 | def visit_merge(self, node: MergeRecipeNode): 231 | return max( 232 | child.accept(self) 233 | for children in (node.args, node.kwargs.values()) 234 | for child in children 235 | ) + 1 236 | 237 | 238 | class ModelsCountVisitor(RecipeVisitor): 239 | def __init__(self): 240 | self.__seen_nodes = [] 241 | 242 | def visit_literal(self, node: LiteralRecipeNode) -> int: 243 | return 0 244 | 245 | def visit_model(self, node: ModelRecipeNode) -> int: 246 | seen = node in self.__seen_nodes 247 | self.__seen_nodes.append(node) 248 | return int(not seen) 249 | 250 | def visit_merge(self, node: MergeRecipeNode) -> int: 251 | return sum( 252 | child.accept(self) 253 | for children in (node.args, node.kwargs.values()) 254 | for child in children 255 | ) 256 | -------------------------------------------------------------------------------- /sd_mecha/serialization.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | from typing import List, Optional, Hashable 4 | from .extensions import merge_methods 5 | from .recipe_nodes import RecipeNode, ModelRecipeNode, RecipeVisitor, LiteralRecipeNode, 
MergeRecipeNode  # (end of the `from .recipe_nodes import ...` statement begun in the previous chunk)


MECHA_FORMAT_VERSION = "0.1.0"


def deserialize_path(recipe: pathlib.Path) -> RecipeNode:
    """Load and deserialize a `.mecha` recipe file from disk."""
    if not recipe.exists():
        raise ValueError(f"unable to deserialize '{recipe}': no such file")

    if recipe.suffix == ".mecha":
        with open(recipe, "r") as recipe_file:
            return deserialize(recipe_file.read())
    else:
        raise ValueError(f"unable to deserialize '{recipe}': unknown extension")


def deserialize(recipe: str | List[str]) -> RecipeNode:
    """
    Recreate a recipe graph from its serialized `.mecha` format.

    Args:
        recipe (str or List[str]):
            The textual representation (as a string or list of lines) of the recipe.

    Returns:
        A `RecipeNode` that can be further merged or manipulated.
    """
    if not isinstance(recipe, list):
        recipe = recipe.split("\n")

    if not recipe[0].startswith("version"):
        raise RuntimeError("bad format: expected version at line 1")

    actual_version, recipe = recipe[0], recipe[1:]
    expected_version = get_version_header(MECHA_FORMAT_VERSION)
    if actual_version != expected_version:
        raise RuntimeError(f"bad recipe version: got {actual_version}, expected {expected_version}")

    results = []

    def parse(line):
        line = line.strip()
        # fix: skip empty lines as well as comments; previously a blank line
        # (e.g. the trailing newline of a saved .mecha file) crashed on the
        # `command, *args` unpack below because tokenize("") returns [].
        if not line or line.startswith("#"):
            return

        command, *args = tokenize(line)
        positional_args, keyword_args = preprocess_args(args)
        if command == "dict":
            results.append(dict(*positional_args, **keyword_args))
        elif command == "literal":
            results.append(LiteralRecipeNode(*positional_args, **keyword_args))
        elif command == "model":
            path = pathlib.Path(positional_args[0])
            results.append(ModelRecipeNode(path, *positional_args[1:], **keyword_args))
        elif command == "merge":
            method, *positional_args = positional_args
            method = merge_methods.resolve(method)
            results.append(method(*positional_args, **keyword_args))
        else:
            raise ValueError(f"unknown command: {command}")

    def preprocess_args(args):
        positional_args = []
        named_args = {}
        for arg_index, arg in enumerate(args):
            if '=' in arg:
                key, value = arg.split('=', maxsplit=1)
                named_args[key] = get_arg_value(value, arg_index)
            else:
                positional_args.append(get_arg_value(arg, arg_index))
        return positional_args, named_args

    def get_arg_value(arg, arg_index):
        try:
            if arg in CONSTANTS:
                return CONSTANTS[arg]
            elif arg.startswith('&'):
                # back-reference to a previously parsed instruction
                ref_index = int(arg[1:])
                if ref_index < 0 or ref_index >= len(results):
                    raise ValueError(f"reference {arg} out of bounds")
                return results[ref_index]
            elif arg.startswith('"') and arg.endswith('"'):
                return arg[1:-1]
            elif '.' in arg or 'e' in arg.lower():
                return float(arg)
            else:
                return int(arg)
        except ValueError as e:
            raise ValueError(f"argument {arg_index}: {str(e)}")

    def tokenize(line):
        tokens = []
        current_token = []
        quote_prefix = []
        inside_quotes = False
        is_escape = False
        for char in line:
            if is_escape:
                is_escape = False
            elif char == "\\":
                is_escape = True
                continue
            elif char == '"':
                inside_quotes = not inside_quotes
                if inside_quotes:  # Begin of quoted string
                    quote_prefix = current_token
                    current_token = []
                else:  # End of quoted string
                    tokens.append(f'{"".join(quote_prefix)}"{"".join(current_token)}"')
                    current_token = []
                    quote_prefix = []
                continue
            elif char == ' ' and not inside_quotes:
                if current_token:  # End of a token
                    tokens.append(''.join(current_token))
                    current_token = []
                continue
            current_token.append(char)
        if inside_quotes:  # Handle mismatched quotes
            raise ValueError(f"mismatched quotes in input")
        if current_token:  # Add last token if exists
            tokens.append(''.join(current_token))
        return tokens

    for line_num, line in enumerate(recipe, 1):
        try:
            parse(line)
        except ValueError as e:
            raise ValueError(f"line {line_num}: {e}.\n {line}")

    return results[-1]


def serialize(recipe: RecipeNode, *, output: Optional[pathlib.Path | str] = None) -> str:
    """
    Convert a recipe graph into a string that captures its merge instructions.

    This is the first step of persisting a recipe to disk in `.mecha` format.

    Args:
        recipe:
            A `RecipeNode` describing the merge.
        output:
            Path to the output file to save.

    Returns:
        A string representation of the recipe, suitable for writing to a .mecha file.
    """
    serializer = SerializerVisitor()
    recipe.accept(serializer)
    version_header = get_version_header(MECHA_FORMAT_VERSION)
    serialized = "\n".join([version_header] + serializer.instructions)

    if isinstance(output, str):
        output = pathlib.Path(output)
    if output is not None:
        output = output.absolute()
        logging.info(f"Saving recipe to {output}")
        output.write_text(serialized)

    return serialized


def get_version_header(version: str):
    """Return the version line that must appear first in a .mecha file."""
    return f"version {version}"


class SerializerVisitor(RecipeVisitor):
    """Serializes a recipe tree into a list of `.mecha` instructions,
    de-duplicating identical instructions via `&N` back-references."""

    def __init__(self, instructions: Optional[List[str]] = None):
        self.instructions = instructions if instructions is not None else []

    def visit_literal(self, node: LiteralRecipeNode):
        value = self.__serialize_value(node.value)
        if node.model_config is None:
            # No config to record: the value is inlined at the use site.
            return value
        else:
            config = self.__serialize_value(node.model_config.identifier)
            merge_space = self.__serialize_value(node.merge_space.identifier)
            line = f"literal {value} model_config={config} merge_space={merge_space}"
            return self.__add_instruction(line)

    def visit_model(self, node: ModelRecipeNode) -> str:
        path = self.__serialize_value(str(node.path))
        config = self.__serialize_value(getattr(node.model_config, "identifier", None))
        merge_space = self.__serialize_value(node.merge_space.identifier)
        line = f"model {path} model_config={config} merge_space={merge_space}"
        return self.__add_instruction(line)

    def visit_merge(self, node: MergeRecipeNode) -> str:
        identifier = self.__serialize_value(node.merge_method.get_identifier())
        parts = ["merge", identifier] + [
            self.__serialize_value(v)
            for v in node.args
        ] + [
            f"{k}={self.__serialize_value(v)}"
            for k, v in node.kwargs.items()
        ]
        line = " ".join(parts)
        return self.__add_instruction(line)

    def __serialize_value(self, value) -> str:
        if isinstance(value, str):
            value = value.replace("\\", "\\\\").replace('"', "\\\"")
            return f'"{value}"'
        if isinstance(value, dict):
            dict_line = "dict " + " ".join(f"{k}={self.__serialize_value(v)}" for k, v in value.items())
            return self.__add_instruction(dict_line)
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return str(value)
        # int or float needs to be handled before this (1.0 == True)
        if isinstance(value, Hashable) and value in REVERSE_CONSTANTS:
            return REVERSE_CONSTANTS[value]
        if isinstance(value, RecipeNode):
            return value.accept(self)
        raise TypeError(f"type {type(value)} cannot be serialized: {value}")

    def __add_instruction(self, instruction: str) -> str:
        # Reuse an existing identical instruction if present (structural sharing).
        try:
            return f"&{self.instructions.index(instruction)}"
        except ValueError:
            self.instructions.append(instruction)
            return f"&{len(self.instructions) - 1}"


CONSTANTS = {
    "null": None,
    "true": True,
    "false": False,
}
REVERSE_CONSTANTS = {
    v: k for k, v in CONSTANTS.items()
}
# --- sd_mecha/streaming.py ---
import abc
import contextlib
import ctypes
import dataclasses
import json
import pathlib
import struct
import sys
import threading
import numpy
import torch
import warnings
from collections import OrderedDict
from typing import Optional, Mapping, Iterator, Iterable, Tuple
from tqdm import tqdm
from .typing_ import WriteOnlyMapping


@dataclasses.dataclass
class TensorMetadata:
    """Shape/dtype metadata of a tensor as stored in a safetensors header."""
    shape: Optional[torch.Size]
    dtype: Optional[torch.dtype]

    def __post_init__(self):
        # Accept JSON-decoded forms (list / dtype name string) transparently.
        if isinstance(self.shape, list):
            self.shape = torch.Size(self.shape)
        if isinstance(self.dtype, str):
            self.dtype = getattr(torch, self.dtype)

    def safetensors_header_value(self, data_offset: int):
        """Return the safetensors header entry for this tensor at `data_offset`."""
        if self.shape is None:
            raise RuntimeError("invalid operation: metadata doesn't have shape")

        if self.dtype is None:
            raise RuntimeError("invalid operation: metadata doesn't have dtype")

        return {
            "shape": list(self.shape),
            "dtype": DTYPE_REVERSE_MAPPING[self.dtype][0],
            "data_offsets": [data_offset, data_offset + self.get_byte_size()]
        }

    def get_byte_size(self) -> int:
        return self.numel() * self.get_dtype_size()

    def get_dtype_size(self) -> int:
        if self.dtype is None:
            raise RuntimeError("invalid operation: metadata doesn't have dtype")

        return DTYPE_REVERSE_MAPPING[self.dtype][1]

    def numel(self) -> int:
        if self.shape is None:
            raise RuntimeError("invalid operation: metadata doesn't have shape")

        return self.shape.numel()


class SafetensorsMapping(Mapping[str, torch.Tensor], abc.ABC):
    """Read-only mapping interface over the tensors of a safetensors file."""

    @abc.abstractmethod
    def keys(self) -> Iterable[str]:
        ...

    @abc.abstractmethod
    def metadata(self) -> Iterable[Tuple[str, TensorMetadata]]:
        ...

    @abc.abstractmethod
    def values(self) -> Iterable[torch.Tensor]:
        ...

    @abc.abstractmethod
    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        ...


class InSafetensorsDict(SafetensorsMapping):
    """Streams tensors out of a safetensors file through a reusable read buffer."""

    def __init__(self, file_path: pathlib.Path, buffer_size):
        if not file_path.suffix == ".safetensors":
            raise ValueError(f"Model type not supported: {file_path} (only safetensors are supported)")

        self.default_buffer_size = buffer_size
        self.file = open(file_path, mode='rb', buffering=0)
        self.file_path = file_path
        self.header_size, self.header = self._read_header()
        self.buffer = bytearray()
        self.buffer_start_offset = 8 + self.header_size
        self.lock = threading.Lock()

    def __del__(self):
        self.close()

    def __getitem__(self, key: str) -> torch.Tensor:
        if key not in self.header or key == "__metadata__":
            raise StateDictKeyError(key)
        return self._load_tensor(key)

    def __iter__(self) -> Iterator[str]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self.header) - int("__metadata__" in self.header)

    def close(self):
        # getattr guard: close() may run from __del__ on a half-constructed instance
        if getattr(self, "file", None) is not None:
            self.file.close()
            self.file = None
            self.buffer = None
            self.header = None

    def keys(self) -> Iterable[str]:
        return (
            key
            for key in self.header.keys()
            if key != "__metadata__"
        )

    def metadata(self) -> Iterable[Tuple[str, TensorMetadata]]:
        for key in self.keys():
            yield key, TensorMetadata(self.header[key]["shape"], DTYPE_MAPPING[self.header[key]["dtype"]][0])

    def values(self) -> Iterable[torch.Tensor]:
        for key in self.keys():
            yield self[key]

    def items(self) -> Iterable[Tuple[str, torch.Tensor]]:
        for key in self.keys():
            yield key, self[key]

    def _read_header(self):
        # safetensors layout: 8-byte little-endian header length, then JSON header.
        header_size_bytes = self.file.read(8)
        # NOTE(review): the middle of this method was lost in the source dump
        # (tag-stripping ate the span between '<Q' and a later '>');
        # reconstructed from the standard safetensors layout — verify against VCS.
        header_size = struct.unpack('<Q', header_size_bytes)[0]
        header_json = self.file.read(header_size)
        header = json.loads(header_json)
        return header_size, header

    def _ensure_buffer(self, start_pos, length):
        # Refill the buffer only if [start_pos, start_pos+length) is not already in it.
        if start_pos < self.buffer_start_offset or start_pos + length > self.buffer_start_offset + len(self.buffer):
            self.file.seek(start_pos)
            necessary_buffer_size = max(self.default_buffer_size, length)
            if len(self.buffer) < necessary_buffer_size:
                self.buffer = bytearray(necessary_buffer_size)
            else:
                self.buffer = self.buffer[:necessary_buffer_size]

            self.file.readinto(self.buffer)
            self.buffer_start_offset = start_pos

    def _load_tensor(self, tensor_name):
        tensor_info = self.header[tensor_name]
        offsets = tensor_info['data_offsets']
        dtype, dtype_bytes = DTYPE_MAPPING[tensor_info['dtype']]
        shape = tensor_info['shape']
        total_bytes = offsets[1] - offsets[0]
        if total_bytes == 0:
            return torch.tensor([], dtype=dtype).reshape(shape)

        absolute_start_pos = 8 + self.header_size + offsets[0]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.lock:
                self._ensure_buffer(absolute_start_pos, total_bytes)
                buffer_offset = absolute_start_pos - self.buffer_start_offset
                return torch.frombuffer(self.buffer, count=total_bytes // dtype_bytes, offset=buffer_offset, dtype=dtype).reshape(shape)


class StateDictKeyError(KeyError):
    """
    Exception raised when a requested key is missing from a streamed or in-memory state dict.

    It behaves like a normal `KeyError`, but is specialized for reporting missing keys
    within streaming merges or recipes.
    """


@dataclasses.dataclass
class OutSafetensorsDictThreadState:
    """Per-thread write buffer and partial header for OutSafetensorsDict."""
    buffer: bytearray
    memory_used: int = dataclasses.field(default=0)
    sub_header: dict = dataclasses.field(default_factory=dict)


class OutSafetensorsDict(WriteOnlyMapping[str, torch.Tensor]):
    """Write-only safetensors serializer with per-thread buffering.

    Header space is reserved up front using a worst-case size estimate and
    patched in `close()`.
    """

    def __init__(
        self,
        file_path: pathlib.Path,
        header: Mapping[str, TensorMetadata],
        mecha_recipe: Optional[str],
        minimum_buffer_size: int,
    ):
        self.thread_states = {}
        self.lock = threading.Lock()

        self.header = {
            "__metadata__": {"mecha_recipe": mecha_recipe} if mecha_recipe is not None else {}
        }
        self.file = file_path.open("wb", buffering=0)
        self.file_path = file_path
        self.flushed_size = 0
        self.minimum_buffer_size = minimum_buffer_size

        self.max_header_size = self._init_buffer(header)

    def __del__(self):
        # fix: guard against partially-constructed instances (e.g. open() raised),
        # matching the style of InSafetensorsDict.close
        file = getattr(self, "file", None)
        if file is not None:
            file.close()
        self.thread_states = None
        self.header = None

    def __setitem__(self, key: str, tensor: torch.Tensor) -> None:
        tid = threading.current_thread().ident
        if tid not in self.thread_states:
            self.thread_states[tid] = OutSafetensorsDictThreadState(bytearray(self.minimum_buffer_size))

        state = self.thread_states[tid]

        tensor_bytes = tensor_to_bytes(tensor)
        tensor_size = len(tensor_bytes)

        if tensor_size > len(state.buffer) - state.memory_used:
            self._flush_buffer(state, next_tensor_size=tensor_size)

        local_offset = state.memory_used
        state.buffer[state.memory_used:state.memory_used + tensor_size] = tensor_bytes
        state.memory_used += tensor_size

        # Offsets are buffer-local here; they are rebased in _flush_buffer.
        state.sub_header[key] = {
            "dtype": DTYPE_REVERSE_MAPPING[tensor.dtype][0],
            "shape": list(tensor.shape),
            "data_offsets": [local_offset, local_offset + tensor_size]
        }

    def __len__(self) -> int:
        return len(self.header)

    def _init_buffer(self, header: Mapping[str, TensorMetadata]) -> int:
        """Estimate the worst-case header size and reserve space for it."""
        trimmed_header = {
            k: v for k, v in header.items() if v.shape is not None and v.dtype is not None
        }
        worst_case_header = OrderedDict(sorted(
            trimmed_header.items(),
            key=lambda item: item[1].get_byte_size(),
            reverse=True,  # simulate worst case: maximize space taken by order
        ))

        data_offset = 0
        dummy_safetensors_header = OrderedDict(self.header)
        for k, v in worst_case_header.items():
            dummy_safetensors_header[k] = v.safetensors_header_value(data_offset)
            data_offset += v.get_byte_size()

        header_json = json.dumps(dummy_safetensors_header, separators=(',', ':')).encode('utf-8')
        max_header_size = len(header_json)
        self.file.seek(8 + max_header_size)  # Reserve space for the header
        return max_header_size

    def _flush_buffer(self, state: OutSafetensorsDictThreadState, next_tensor_size: Optional[int] = None, close: bool = False):
        # During close() the caller already holds self.lock, so don't re-acquire it.
        if not close:
            lock = self.lock
        else:
            lock = contextlib.nullcontext()

        with lock:
            self.file.write(state.buffer[:state.memory_used])
            buffer_offset = self.flushed_size
            self.flushed_size += state.memory_used

        state.memory_used = 0

        if next_tensor_size is not None:
            required_buffer_size = max(self.minimum_buffer_size, next_tensor_size)
            if len(state.buffer) < required_buffer_size:
                state.buffer = bytearray(required_buffer_size)
            else:
                state.buffer = state.buffer[:required_buffer_size]

        # Rebase buffer-local offsets to absolute data-section offsets.
        global_sub_header = {
            k: {
                attr: val
                if attr != "data_offsets"
                else (val[0] + buffer_offset, val[1] + buffer_offset)
                for attr, val in v.items()
            }
            for k, v in state.sub_header.items()
        }
        self.header.update(global_sub_header)
        state.sub_header.clear()
        if close:
            state.buffer = b""

    def close(self):
        with self.lock:
            for state in self.thread_states.values():
                self._flush_buffer(state, close=True)

            header_json = json.dumps(self.header, separators=(',', ':')).encode('utf-8')
            header_size = len(header_json)
            overhead = self.max_header_size - header_size

            if overhead < 0:
                # not enough space. we have to move the entire data section by `-overhead`
                # this should never happen, but it's here just in case as a fallback
                data_offset = -overhead
                old_data_section = 8 + self.max_header_size
                old_file_end = 8 + self.max_header_size + self.flushed_size
                new_file_end = 8 + header_size + self.flushed_size
                self.file.truncate(new_file_end)

                # close and reopen the file in read-write mode
                self.file.close()
                self.file = open(self.file_path, "rb+")

                # move data in chunks from the end to avoid overwriting
                for chunk_end in tqdm(range(old_file_end, old_data_section, -self.minimum_buffer_size), desc="Reallocating data section"):
                    chunk_start = max(chunk_end - self.minimum_buffer_size, old_data_section)
                    chunk_size = chunk_end - chunk_start
                    self.file.seek(chunk_start)
                    data = self.file.read(chunk_size)

                    # calculate the new position and write the chunk
                    self.file.seek(chunk_start + data_offset)
                    self.file.write(data)

                # we made just enough space for the header
                overhead = 0

            self.file.seek(0)
            # NOTE(review): the tail of this method was lost in the source dump
            # (tag-stripping ate the span between '<Q' and the next '>');
            # reconstructed as: write the padded header length, the header JSON,
            # space padding, then close — verify against VCS.
            self.file.write(struct.pack('<Q', header_size + overhead))
            self.file.write(header_json)
            self.file.write(b" " * overhead)
            self.file.close()


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    """Return the raw little-endian bytes of `tensor` without copying on LE hosts."""
    # assume tensor is not sparse nor contiguous and on the cpu
    total_bytes = len(tensor.untyped_storage())

    ptr = tensor.data_ptr()
    if ptr == 0:
        return b""
    newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
    data = numpy.ctypeslib.as_array(newptr, (total_bytes,))  # no internal copy
    if sys.byteorder == "big":
        NPDTYPES = {
            torch.int64: numpy.int64,
            torch.float32: numpy.float32,
            torch.int32: numpy.int32,
            # XXX: This is ok because both have the same width
            torch.bfloat16: numpy.float16,
            torch.float16: numpy.float16,
            torch.int16: numpy.int16,
            torch.uint8: numpy.uint8,
            torch.int8: numpy.int8,
            torch.bool: bool,
            torch.float64: numpy.float64,
            # XXX: This is ok because both have the same width and byteswap is a no-op anyway
            torch.float8_e4m3fn: numpy.uint8,
            torch.float8_e5m2: numpy.uint8,
        }
        npdtype = NPDTYPES[tensor.dtype]
        # Not in place as that would potentially modify a live running model
        data = data.view(npdtype).byteswap(inplace=False)
    return data.tobytes()


DTYPE_MAPPING = {
    "F64": (torch.float64, 8),
    "I64": (torch.int64, 8),
    "F32": (torch.float32, 4),
    "I32": (torch.int32, 4),
    "F16": (torch.float16, 2),
    "BF16": (torch.bfloat16, 2),
    "I16": (torch.int16, 2),
    "I8": (torch.int8, 1),
    "F8_E4M3": (torch.float8_e4m3fn, 1),
    "F8_E5M2": (torch.float8_e5m2, 1),
    "BOOL": (torch.bool, 1),
}
# Register unsigned types only where this torch build provides them.
for i, dtype_str in enumerate(("uint8", "uint16", "uint32", "uint64")):
    if hasattr(torch, dtype_str):
        num_bytes = 2**i
        num_bits = 8 * num_bytes
        DTYPE_MAPPING[f"U{num_bits}"] = (getattr(torch, dtype_str), num_bytes)

DTYPE_REVERSE_MAPPING = {v: (k, b) for k, (v, b) in DTYPE_MAPPING.items()}


# --- sd_mecha/typing_.py (header; the class continues in the next chunk) ---
import abc
import typing
from types import UnionType
from typing import runtime_checkable, Protocol, TypeVar


K = TypeVar("K")
V = TypeVar("V")


@runtime_checkable
class WriteOnlyMapping(Protocol[K, V]):
    @abc.abstractmethod
    def __setitem__(self, key: K, value: V) -> None:
        ...
16 | 17 | @abc.abstractmethod 18 | def __len__(self) -> int: 19 | ... 20 | 21 | 22 | def is_subclass(source: type | UnionType, target: type | UnionType): 23 | source_origin = typing.get_origin(source) or source 24 | target_origin = typing.get_origin(target) or target 25 | if isinstance(source_origin, TypeVar): 26 | return False 27 | if isinstance(target_origin, TypeVar): 28 | return any(is_subclass(source_origin, constraint) for constraint in target_origin.__constraints__) 29 | if issubclass(source_origin, UnionType): 30 | return all(is_subclass(arg, target) for arg in typing.get_args(source)) 31 | return issubclass(source_origin, target) 32 | -------------------------------------------------------------------------------- /tests/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljleb/sd-mecha/8150c2a16074cb8463751ddcddd8572d31830882/tests/extensions/__init__.py -------------------------------------------------------------------------------- /tests/extensions/test_merge_methods.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tempfile 3 | import typing 4 | import pytest 5 | import safetensors.torch 6 | import torch 7 | import sd_mecha 8 | from typing import TypeVar, Mapping 9 | from sd_mecha import merge_method, Parameter, Return, StateDict, StateDictKeyError 10 | 11 | 12 | A = TypeVar("A") 13 | B = TypeVar("B") 14 | 15 | 16 | def assert_equal_in_merge_method(expected: A, actual_literal: B, t: type[A] | TypeVar): 17 | return_t = next(iter(typing.get_args(t))) if typing.get_args(t) else t 18 | 19 | @merge_method(register=False) 20 | def compare_value( 21 | actual: Parameter(t, model_config="sdxl-supermerger_blocks"), 22 | **kwargs, 23 | ) -> Return(return_t, model_config="sdxl-supermerger_blocks"): 24 | nonlocal expected 25 | 26 | try: 27 | if isinstance(expected, Mapping): 28 | expected_value = 
expected[kwargs["key"]] 29 | else: 30 | expected_value = expected 31 | except KeyError as e: 32 | raise StateDictKeyError(str(e)) from e 33 | 34 | if isinstance(actual, Mapping): 35 | actual = actual[kwargs["key"]] 36 | 37 | if isinstance(actual, torch.Tensor): 38 | assert torch.allclose(expected_value, actual) 39 | else: 40 | assert actual == expected_value and isinstance(actual, type(expected_value)) 41 | 42 | return actual 43 | 44 | sd_mecha.merge( 45 | compare_value(actual_literal), 46 | strict_weight_space=False, 47 | threads=0, 48 | merge_device=None, 49 | merge_dtype=None, 50 | output_device=None, 51 | output_dtype=None, 52 | ) 53 | 54 | 55 | def test_value_to_node__float_to_tensor(): 56 | actual = 1.0 57 | expected = torch.tensor(1.0) 58 | assert_equal_in_merge_method(expected, actual, torch.Tensor) 59 | 60 | 61 | def test_value_to_node__str(): 62 | actual = "hello!" 63 | expected = "hello!" 64 | assert_equal_in_merge_method(expected, actual, str) 65 | 66 | 67 | def test_value_to_node__float_dict_to_tensor(): 68 | actual = {"IN00": float(1.0), "IN01": float(2.0)} 69 | expected = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 70 | assert_equal_in_merge_method(expected, actual, torch.Tensor) 71 | 72 | 73 | def test_value_to_node__float_dict_to_tensor_dict(): 74 | actual = {"IN00": float(1.0), "IN01": float(2.0)} 75 | expected = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 76 | assert_equal_in_merge_method(expected, actual, StateDict[torch.Tensor]) 77 | 78 | 79 | def test_value_to_node__int_dict_to_tensor_dict(): 80 | actual = {"IN00": int(1), "IN01": int(2)} 81 | expected = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 82 | assert_equal_in_merge_method(expected, actual, StateDict[torch.Tensor]) 83 | 84 | 85 | def test_value_to_node__tensor_dict_to_tensor_dict(): 86 | actual = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 87 | expected = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 88 | 
assert_equal_in_merge_method(expected, actual, StateDict[torch.Tensor]) 89 | 90 | 91 | def test_value_to_node__float_to_int(): 92 | actual = float(1.5) 93 | expected = int(1) 94 | assert_equal_in_merge_method(expected, actual, int) 95 | 96 | 97 | def test_value_to_node__int_to_float(): 98 | actual = int(1) 99 | expected = float(1.0) 100 | assert_equal_in_merge_method(expected, actual, float) 101 | 102 | 103 | def test_value_to_node__path_to_tensor(): 104 | tmp = tempfile.mktemp(suffix=".safetensors") 105 | actual = pathlib.Path(tmp) 106 | try: 107 | expected = {"IN00": torch.tensor(0.0)} 108 | safetensors.torch.save_file(expected, tmp) 109 | assert_equal_in_merge_method(expected, actual, torch.Tensor) 110 | finally: 111 | pathlib.Path(tmp).unlink(missing_ok=True) 112 | 113 | 114 | def test_value_to_node__str_dict_to_str(): 115 | actual = {"IN00": "hello!", "IN01": "hello2!"} 116 | expected = {"IN00": "hello!", "IN01": "hello2!"} 117 | assert_equal_in_merge_method(expected, actual, str) 118 | 119 | 120 | def test_value_to_node__inconsistent_dict_type(): 121 | value = {"IN00": torch.tensor(1.0), "IN01": 2.0} 122 | with pytest.raises(TypeError): 123 | assert_equal_in_merge_method(value, value, torch.Tensor) 124 | 125 | 126 | def test_value_to_node__path_to_str(): 127 | tmp = tempfile.mktemp(suffix=".safetensors") 128 | actual = pathlib.Path(tmp) 129 | expected = tmp 130 | with pytest.raises(TypeError): 131 | assert_equal_in_merge_method(expected, actual, str) 132 | 133 | 134 | def test_value_to_node__str_dict_to_str_dict(): 135 | actual = {"IN00": "hello!", "IN01": "hello2!"} 136 | expected = {"IN00": "hello!", "IN01": "hello2!"} 137 | assert_equal_in_merge_method(expected, actual, StateDict[str]) 138 | 139 | 140 | T = TypeVar("T") 141 | 142 | 143 | def test_value_to_node__int_to_type_var(): 144 | actual = {"IN00": 1, "IN01": 2} 145 | expected = {"IN00": 1, "IN01": 2} 146 | assert_equal_in_merge_method(expected, actual, T) 147 | 148 | 149 | def 
test_value_to_node__int_to_type_var_dict(): 150 | actual = {"IN00": 1, "IN01": 2} 151 | expected = {"IN00": 1, "IN01": 2} 152 | assert_equal_in_merge_method(expected, actual, StateDict[T]) 153 | 154 | 155 | def test_value_to_node__tensor_to_float(): 156 | actual = {"IN00": torch.tensor(1.0), "IN01": torch.tensor(2.0)} 157 | expected = {"IN00": float(1.0), "IN01": float(2.0)} 158 | assert_equal_in_merge_method(expected, actual, float) 159 | 160 | 161 | def test_value_to_node__tensor_to_int(): 162 | actual = {"IN00": torch.tensor(1), "IN01": torch.tensor(2)} 163 | expected = {"IN00": int(1), "IN01": int(2)} 164 | assert_equal_in_merge_method(expected, actual, int) 165 | 166 | 167 | def test_value_to_node__path_to_type_var(): 168 | tmp = tempfile.mktemp(suffix=".safetensors") 169 | actual = pathlib.Path(tmp) 170 | try: 171 | expected = {"IN00": torch.tensor(0.0)} 172 | safetensors.torch.save_file(expected, tmp) 173 | assert_equal_in_merge_method(expected, actual, T) 174 | finally: 175 | pathlib.Path(tmp).unlink(missing_ok=True) 176 | 177 | 178 | def test_value_to_node__path_to_type_var_dict(): 179 | tmp = tempfile.mktemp(suffix=".safetensors") 180 | actual = pathlib.Path(tmp) 181 | try: 182 | expected = {"IN00": torch.tensor(0.0)} 183 | safetensors.torch.save_file(expected, tmp) 184 | assert_equal_in_merge_method(expected, actual, StateDict[T]) 185 | finally: 186 | pathlib.Path(tmp).unlink(missing_ok=True) 187 | -------------------------------------------------------------------------------- /tests/merge_methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ljleb/sd-mecha/8150c2a16074cb8463751ddcddd8572d31830882/tests/merge_methods/__init__.py -------------------------------------------------------------------------------- /tests/merge_methods/test_della.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sd_mecha 3 | 4 | 5 | 
def test_della(): 6 | k = 1.0 7 | 8 | probability = 0.40 9 | della_eps = 0.20 10 | no_della = 0.0 11 | seed = 114514 12 | cos_eps = 1e-6 13 | 14 | eps = 1e-6 15 | maxiter = 100 16 | ftol = 1e-20 17 | 18 | models = [ 19 | torch.tensor([ 20 | [ 3., 4., 1., -2.], 21 | [ 2., 1., -4., 3.], 22 | [-1., 2., 3., 4.], 23 | [ 4., -3., 2., 1.], 24 | ]), 25 | torch.tensor([ 26 | [-1., 3., 4., 2.], 27 | [ 4., 2., -3., 1.], 28 | [ 3., -1., 2., 4.], 29 | [ 2., 4., 1., -3.], 30 | ]), 31 | torch.tensor([ 32 | [-1., 3., 2., 0.], 33 | [ 3., 0., 2., 1.], 34 | [-1., 3., 1., 0.], 35 | [ 3., 0., -4., 3.] 36 | ]) 37 | ] 38 | 39 | expected = torch.tensor([ 40 | [ 0.6251, 1.8496, 2.4691, 0.0000], 41 | [ 3.1303, 0.2084, -0.8335, 1.0780], 42 | [ 1.0162, 1.7755, 1.0780, 0.0000], 43 | [ 2.0362, 0.0000, 0.4167, 0.2084] 44 | ]) 45 | 46 | expected2 = torch.tensor([ 47 | [ 0.7469, 1.9883, 2.4127, 0.0000], 48 | [ 1.7586, 0.9106, -0.9959, 1.1671], 49 | [ 0.9925, 1.7586, 1.1671, 0.0000], 50 | [ 1.9223, 0.0000, 0.4979, 0.2490] 51 | ]) 52 | 53 | test_no_della = sd_mecha.ties_sum_with_dropout.__wrapped__( 54 | *models, 55 | probability=probability, 56 | della_eps=no_della, 57 | rescale=False, 58 | k=k, 59 | vote_sgn=True, 60 | seed=seed, 61 | apply_stock=False, 62 | apply_median=True, 63 | cos_eps=cos_eps, 64 | eps=eps, 65 | maxiter=maxiter, 66 | ftol=ftol, 67 | ) 68 | assert torch.allclose(test_no_della, expected, atol=0.0001) 69 | 70 | actual_della = sd_mecha.ties_sum_with_dropout.__wrapped__( 71 | *models, 72 | probability=probability, 73 | della_eps=della_eps, 74 | rescale=False, 75 | k=k, 76 | vote_sgn=True, 77 | seed=seed, 78 | apply_stock=False, 79 | apply_median=True, 80 | cos_eps=cos_eps, 81 | eps=eps, 82 | maxiter=maxiter, 83 | ftol=ftol, 84 | ) 85 | assert torch.allclose(actual_della, expected, atol=0.0001) 86 | 87 | actual_della_flipped = sd_mecha.ties_sum_with_dropout.__wrapped__( 88 | *models, 89 | probability=probability, 90 | della_eps=-della_eps, 91 | rescale=False, 92 | k=k, 93 | 
vote_sgn=True, 94 | seed=seed, 95 | apply_stock=False, 96 | apply_median=True, 97 | cos_eps=cos_eps, 98 | eps=eps, 99 | maxiter=maxiter, 100 | ftol=ftol, 101 | ) 102 | assert torch.allclose(actual_della_flipped, expected2, atol=0.0001) 103 | -------------------------------------------------------------------------------- /tests/merge_methods/test_geometric_median.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sd_mecha 3 | 4 | 5 | def test_geometric_median(): 6 | eps = 1e-6 7 | maxiter = 100 8 | ftol = 1e-20 9 | 10 | models = [ 11 | torch.tensor([ 12 | [ 3., 4., 1., -2.], 13 | [ 2., 1., -4., 3.], 14 | [-1., 2., 3., 4.], 15 | [ 4., -3., 2., 1.], 16 | ]), 17 | torch.tensor([ 18 | [-1., 3., 4., 2.], 19 | [ 4., 2., -3., 1.], 20 | [ 3., -1., 2., 4.], 21 | [ 2., 4., 1., -3.], 22 | ]), 23 | torch.tensor([ 24 | [-1., 3., 2., 0.], 25 | [ 3., 0., 2., 1.], 26 | [-1., 3., 1., 0.], 27 | [ 3., 0., -4., 3.] 28 | ]) 29 | ] 30 | 31 | models2 = [] 32 | for i in range(100): 33 | models2.append(torch.rand(1280, 1280)) 34 | 35 | expected = torch.tensor([ 36 | [0.4791, 3.3698, 2.2343, -0.1354], 37 | [2.9323, 0.9739, -1.7289, 1.7395], 38 | [0.2082, 1.4220, 2.0416, 2.6873], 39 | [3.0677, 0.0989, -0.2711, 0.4481] 40 | ]) 41 | 42 | median = sd_mecha.geometric_median.__wrapped__(*models, eps=eps, maxiter=maxiter, ftol=ftol) 43 | assert torch.allclose(median, expected, atol=0.0001) 44 | -------------------------------------------------------------------------------- /tests/merge_methods/test_modelstock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sd_mecha 3 | 4 | 5 | def test_modelstock(): 6 | cos_eps = 1e-6 7 | 8 | models = [ 9 | torch.tensor([ 10 | [3., 4., 1., -2.], 11 | [2., 1., -4., 3.], 12 | [-1., 2., 3., 4.], 13 | [4., -3., 2., 1.], 14 | ]), 15 | torch.tensor([ 16 | [-1., 3., 4., 2.], 17 | [4., 2., -3., 1.], 18 | [3., -1., 2., 4.], 19 | [2., 4., 1., -3.], 
20 | ]), 21 | torch.tensor([ 22 | [-1., 2., 3., 4.], 23 | [4., -3., 2., 1.], 24 | [3., 4., 1., -2.], 25 | [2., 1., -4., 3.], 26 | ]) 27 | ] 28 | 29 | expected1 = torch.tensor([ 30 | [0.2727, 2.4545, 2.1818, 1.0909], 31 | [2.5000, 0.0000, -1.2500, 1.2500], 32 | [0.8696, 0.8696, 1.0435, 1.0435], 33 | [-2.0000, -0.5000, 0.2500, -0.2500] 34 | ]) 35 | 36 | stock_only = sd_mecha.model_stock.__wrapped__(*models, cos_eps=cos_eps) 37 | assert torch.allclose(stock_only, expected1, atol=0.0001) 38 | -------------------------------------------------------------------------------- /tests/merge_methods/test_ties.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Comment / LaTeX translation ## 3 | - View in https://upmath.me/ 4 | ### add_difference_ties ### 5 | - `base`: $$ \theta_{init} $$ 6 | - `*models`: $$ \{\theta_{init}\}_{t=1}^n $$ 7 | - `models` after `subtract`: $$ \tau_t $$ 8 | - `alpha`: $$ \lambda $$ 9 | - `k`: $$ k $$ ( From $$ \% $$ to $$ 1 $$ ) 10 | - `res`: $$ \lambda * \tau_m $$ 11 | - `return`: $$ \theta_m $$ 12 | ### ties_sum ### 13 | - `delta`: $$ \hat{\tau}_t $$ 14 | - `signs`: $$ \gamma_t $$ 15 | - `final_sign`: $$ \gamma_m^p = sgn(\sum_{t=1}^n \hat{\tau}_t^p) $$ 16 | - `delta_filters`: $$ \{ \gamma_t^p = \gamma_m^p \} $$ 17 | - `param_counts`: $$ |A^p| $$ 18 | - `filtered_delta`: $$ \sum_{t\in{A^p}} \hat{\tau}_t^p $$ 19 | - `return`: $$ \lambda * \tau_m $$ 20 | """ 21 | 22 | import torch 23 | import sd_mecha 24 | 25 | 26 | def test_ties(): 27 | k = 0.5 28 | 29 | models = [ 30 | torch.tensor([ 31 | [-1., 2., 3., 4.], 32 | [4., -3., 2., 1.], 33 | [3., 4., 1., -2.], 34 | [2., 1., -4., 3.], 35 | ]), 36 | torch.tensor([ 37 | [3., 4., 1., -2.], 38 | [2., 1., -4., 3.], 39 | [-1., 2., 3., 4.], 40 | [4., -3., 2., 1.], 41 | ]) 42 | ] 43 | expected = torch.tensor([ 44 | [3., 3., 3., 0.], 45 | [3., -3., 0., 3.], 46 | [3., 3., 3., 0.], 47 | [3., -3., 0., 3.] 
48 | ]) 49 | 50 | actual = sd_mecha.ties_sum.__wrapped__(*models, k=k) 51 | assert torch.allclose(actual, expected) 52 | 53 | actual2 = sd_mecha.ties_sum.__wrapped__( 54 | *models, 55 | k=k, 56 | vote_sgn=True, 57 | ) 58 | assert not torch.allclose(actual, actual2) 59 | --------------------------------------------------------------------------------