├── .github
│   └── workflows
│       ├── cla.yml
│       └── pre-commit.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CLA.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── create_a_merge_method.md
│   ├── evolve.md
│   ├── merge_methods.md
│   ├── moe.md
│   └── multimerge.md
├── examples
│   ├── arcee_fusion.yml
│   ├── bio-merge.yml
│   ├── gradient-slerp.yml
│   ├── linear.yml
│   ├── orcamini-platy-44layer.yml
│   └── ties.yml
├── mergekit
│   ├── __init__.py
│   ├── _data
│   │   ├── __init__.py
│   │   ├── architectures
│   │   │   ├── __init__.py
│   │   │   ├── baichuan.json
│   │   │   ├── bert-masked-lm.json
│   │   │   ├── bert-sequence-classification.json
│   │   │   ├── bert.json
│   │   │   ├── chatglm.json
│   │   │   ├── cohere.json
│   │   │   ├── distilbert-masked-lm.json
│   │   │   ├── distilbert-sequence-classification.json
│   │   │   ├── distilbert-token-classification.json
│   │   │   ├── distilbert.json
│   │   │   ├── exaone.json
│   │   │   ├── falcon.json
│   │   │   ├── gemma.json
│   │   │   ├── gemma2.json
│   │   │   ├── gemma3.json
│   │   │   ├── gemma3vl.json
│   │   │   ├── glm4.json
│   │   │   ├── gpt-neox.json
│   │   │   ├── gpt2-sequence-classification.json
│   │   │   ├── gpt2.json
│   │   │   ├── gptbigcode.json
│   │   │   ├── internlm2.json
│   │   │   ├── jais.json
│   │   │   ├── llama.json
│   │   │   ├── llama4.json
│   │   │   ├── mamba.json
│   │   │   ├── mistral.json
│   │   │   ├── mistral3.json
│   │   │   ├── olmo2.json
│   │   │   ├── phi-1.json
│   │   │   ├── phi2-old.json
│   │   │   ├── phi2.json
│   │   │   ├── phi3-small.json
│   │   │   ├── phi3.json
│   │   │   ├── qwen.json
│   │   │   ├── qwen2.json
│   │   │   ├── qwen3.json
│   │   │   ├── roberta-masked-lm.json
│   │   │   ├── roberta-sequence-classification.json
│   │   │   ├── roberta-token-classification.json
│   │   │   ├── roberta.json
│   │   │   ├── solar.json
│   │   │   ├── stablelm.json
│   │   │   ├── stablelm2.json
│   │   │   ├── starcoder2.json
│   │   │   ├── t5.json
│   │   │   └── whisper.json
│   │   └── chat_templates
│   │       ├── __init__.py
│   │       ├── alpaca.jinja
│   │       ├── chatml.jinja
│   │       ├── exaone.jinja
│   │       ├── llama3.jinja
│   │       └── mistral.jinja
│   ├── architecture
│   │   ├── __init__.py
│   │   ├── auto.py
│   │   ├── base.py
│   │   ├── json_definitions.py
│   │   └── moe_defs.py
│   ├── card.py
│   ├── common.py
│   ├── config.py
│   ├── evo
│   │   ├── __init__.py
│   │   ├── actors.py
│   │   ├── config.py
│   │   ├── genome.py
│   │   ├── helpers.py
│   │   ├── monkeypatch.py
│   │   └── strategy.py
│   ├── graph.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── lazy_tensor_loader.py
│   │   ├── lazy_unpickle.py
│   │   ├── loader.py
│   │   ├── tasks.py
│   │   └── tensor_writer.py
│   ├── merge.py
│   ├── merge_methods
│   │   ├── __init__.py
│   │   ├── arcee_fusion.py
│   │   ├── base.py
│   │   ├── easy_define.py
│   │   ├── generalized_task_arithmetic.py
│   │   ├── karcher.py
│   │   ├── linear.py
│   │   ├── model_stock.py
│   │   ├── multislerp.py
│   │   ├── nearswap.py
│   │   ├── nuslerp.py
│   │   ├── passthrough.py
│   │   ├── rectify_embed.py
│   │   ├── registry.py
│   │   ├── sce.py
│   │   └── slerp.py
│   ├── moe
│   │   ├── __init__.py
│   │   ├── arch.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── deepseek.py
│   │   ├── mixtral.py
│   │   ├── qwen.py
│   │   └── router.py
│   ├── multigpu_executor.py
│   ├── options.py
│   ├── plan.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── bakllama.py
│   │   ├── evolve.py
│   │   ├── extract_lora.py
│   │   ├── fill_missing_params.py
│   │   ├── layershuffle.py
│   │   ├── legacy.py
│   │   ├── merge_raw_pytorch.py
│   │   ├── moe.py
│   │   ├── multimerge.py
│   │   ├── run_yaml.py
│   │   └── tokensurgeon.py
│   ├── sparsify.py
│   └── tokenizer
│       ├── __init__.py
│       ├── build.py
│       ├── config.py
│       └── embed.py
├── notebook.ipynb
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── common.py
    ├── test_basic_merges.py
    ├── test_chat_template.py
    ├── test_graph.py
    ├── test_io.py
    ├── test_lazy_unpickle.py
    ├── test_modelref.py
    ├── test_sparsify.py
    └── test_tokenizer.py
/.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "CLA Assistant" 2 | on: 3 | issue_comment: 4 | types: [created] 5 | pull_request_target: 6 | types:
[opened,closed,synchronize] 7 | 8 | permissions: 9 | actions: write 10 | contents: read 11 | pull-requests: write 12 | statuses: write 13 | 14 | jobs: 15 | CLAAssistant: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: "CLA Assistant" 19 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' 20 | uses: contributor-assistant/github-action@v2.6.1 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 23 | PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_ACCESS_TOKEN }} 24 | with: 25 | path-to-signatures: 'signatures/version1/cla.json' 26 | path-to-document: 'https://github.com/arcee-ai/mergekit/blob/main/CLA.md' 27 | remote-organization-name: 'arcee-ai' 28 | remote-repository-name: 'cla-signatures' 29 | branch: 'signatures' 30 | allowlist: bot*,dependabot* 31 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | pre-commit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Install uv 13 | uses: astral-sh/setup-uv@v5 14 | with: 15 | version: "0.5.29" 16 | - name: Install Python 17 | run: uv python install 18 | - name: Run pre-commit 19 | run: uv run --no-sync --with pre-commit-uv pre-commit run --show-diff-on-failure --color=always --all-files 20 | 21 | pytest: 22 | if: github.ref == 'refs/heads/main' || github.event_name == 'pull_request' 23 | name: PyTest 24 | needs: [pre-commit] 25 | runs-on: ubuntu-latest 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python_version: ["3.10", "3.11", "3.12"] 30 | timeout-minutes: 5 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | - name: Install uv 35 | uses: astral-sh/setup-uv@v5 36 | with: 37 | version: "0.5.29" 38 | - name: Install Project 39 | run: uv python pin ${{ matrix.python_version }} && uv sync --dev --no-extra vllm --extra test 40 | - name: Run PyTest 41 | run: uv run pytest 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-yaml 7 | args: ["--allow-multiple-documents"] 8 | - id: check-json 9 | - id: check-ast 10 | - repo: https://github.com/PyCQA/isort 11 | rev: 6.0.0 12 | hooks: 13 | - id: isort 14 | - repo: https://github.com/psf/black 15 | rev: 25.1.0 16 | hooks: 17 | - id: black 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v5.0.0 20 | hooks: 21 | - id: trailing-whitespace 22 | - id: end-of-file-fixer 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. 2 | "Business Source License" is a trademark of MariaDB Corporation Ab. 3 | 4 | Parameters 5 | 6 | Licensor: Arcee AI 7 | Licensed Work: MergeKit Version 0.1.0 or later. The Licensed Work is (c) 2025 8 | Arcee AI. 9 | Additional Use Grant: You may make production use of the Licensed Work so long as 10 | you and your affiliates, considered both individually and in 11 | the aggregate, do not meet any of the following criteria: 12 | (i) have more than 100 total full-time equivalent employees 13 | and/or contractors, (ii) have more than USD $10 million in 14 | annual recurring revenue, (iii) make available products and 15 | services with more than 1 million daily active users, (iv) are 16 | a publicly traded company with a market capitalization of 17 | greater than USD $300 million, or (v) are a private company 18 | whose most recent financing post-money valuation was greater 19 | than USD $300 million. 20 | 21 | For clarity, the use of the Licensed Work to create models for 22 | production use constitutes production use. 23 | Change Date: Two years from the date the Licensed Work is published. 24 | Change License: GNU Lesser General Public License v3.0 or later 25 | 26 | For information about alternative licensing arrangements for the Licensed Work, 27 | please contact licensing@arcee.ai. 28 | 29 | Notice 30 | 31 | Business Source License 1.1 32 | 33 | Terms 34 | 35 | The Licensor hereby grants you the right to copy, modify, create derivative 36 | works, redistribute, and make non-production use of the Licensed Work. The 37 | Licensor may make an Additional Use Grant, above, permitting limited production use. 38 | 39 | Effective on the Change Date, or the fourth anniversary of the first publicly 40 | available distribution of a specific version of the Licensed Work under this 41 | License, whichever comes first, the Licensor hereby grants you rights under 42 | the terms of the Change License, and the rights granted in the paragraph 43 | above terminate. 44 | 45 | If your use of the Licensed Work does not comply with the requirements 46 | currently in effect as described in this License, you must purchase a 47 | commercial license from the Licensor, its affiliated entities, or authorized 48 | resellers, or you must refrain from using the Licensed Work. 49 | 50 | All copies of the original and modified Licensed Work, and derivative works 51 | of the Licensed Work, are subject to this License. 
This License applies 52 | separately for each version of the Licensed Work and the Change Date may vary 53 | for each version of the Licensed Work released by Licensor. 54 | 55 | You must conspicuously display this License on each original or modified copy 56 | of the Licensed Work. If you receive the Licensed Work in original or 57 | modified form from a third party, the terms and conditions set forth in this 58 | License apply to your use of that work. 59 | 60 | Any use of the Licensed Work in violation of this License will automatically 61 | terminate your rights under this License for the current and all other 62 | versions of the Licensed Work. 63 | 64 | This License does not grant you any right in any trademark or logo of 65 | Licensor or its affiliates (provided that you may use a trademark or logo of 66 | Licensor as expressly required by this License). 67 | 68 | TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON 69 | AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, 70 | EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF 71 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND 72 | TITLE. 73 | -------------------------------------------------------------------------------- /docs/multimerge.md: -------------------------------------------------------------------------------- 1 | # mergekit-multi: Multi-Stage Model Merging 2 | 3 | ## What is mergekit-multi? 4 | 5 | `mergekit-multi` is a command-line tool for executing complex model merging workflows with multiple interdependent stages. It allows you to: 6 | 7 | 1. Chain multiple merge operations together 8 | 2. Use outputs from previous merges as inputs to subsequent ones 9 | 3. Automatically handle dependencies between merge steps 10 | 4. Cache intermediate results for faster re-runs 11 | 12 | ## Usage 13 | 14 | Basic command structure: 15 | ```bash 16 | mergekit-multi config.yml \ 17 | --intermediate-dir ./intermediates \ 18 | [--out-path ./final-merge] \ # required when the config contains an unnamed (final) merge 19 | [options] 20 | ``` 21 | 22 | ## Configuration File Format 23 | 24 | Create a YAML file with multiple merge configurations separated by `---`. Each document should contain: 25 | 26 | - `name`: Unique identifier for each intermediate merge (omitted for the final merge) 27 | - Standard mergekit configuration parameters 28 | 29 | ### Example with Final Merge (`multimerge.yaml`): 30 | ```yaml 31 | name: first-merge 32 | merge_method: linear 33 | models: 34 | - model: mistralai/Mistral-7B-v0.1 35 | - model: BioMistral/BioMistral-7B 36 | parameters: 37 | weight: 0.5 38 | --- 39 | name: second-merge 40 | merge_method: slerp 41 | base_model: first-merge # Reference previous merge 42 | models: 43 | - model: NousResearch/Hermes-2-Pro-Mistral-7B 44 | parameters: 45 | t: 0.5 46 | --- 47 | # Final merge (no name) 48 | merge_method: dare_ties 49 | base_model: mistralai/Mistral-7B-v0.1 50 | models: 51 | - model: second-merge 52 | parameters: 53 | density: 0.6 54 | weight: 0.5 55 | - model: teknium/OpenHermes-2.5-Mistral-7B 56 | parameters: 57 | density: 0.8 58 | weight: 0.5 59 | ``` 60 | 61 | ### Example with All Named Merges: 62 | ```yaml 63 | name: first-merge 64 | merge_method: task_arithmetic 65 | ... 66 | --- 67 | name: second-merge 68 | merge_method: slerp 69 | ... 70 | --- 71 | name: third-merge 72 | merge_method: linear 73 | ... 74 | ``` 75 | 76 | ## Key Options 77 | 78 | - `--intermediate-dir`: Directory in which to store intermediate merge results 79 | - `--out-path`: Output path for the final merge (only applies when exactly one merge has no `name`) 80 | - `--lazy/--no-lazy`: Skip intermediate merges whose outputs already exist instead of rerunning them (default: true) 81 | - Standard mergekit options apply (e.g., `--cuda`, `--out-shard-size`, `--multi-gpu`) 82 | 83 | ## How It Works 84 | 85 | When you run `mergekit-multi`, it topologically sorts your merge configurations to determine the correct order of execution. The merges are then processed sequentially, using outputs from previous steps as inputs for subsequent ones as needed. 86 | 87 | All intermediate merges are saved in your specified `--intermediate-dir` under their configured names. By default, the tool will skip any merge operations that already have existing output files. To force re-execution of all merges, use the `--no-lazy` flag. 88 |
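To make the execution model concrete, here is a minimal Python sketch of the scheduling idea described above. This is a toy illustration, not mergekit's actual implementation; the document names and the `FINAL` placeholder are invented for the example.

```python
# Toy sketch (not mergekit's real planner): run each merge only after every
# merge it references by name, and skip steps whose outputs already exist
# (the default --lazy behavior).
import os
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical parsed config: document name -> names of merges it references.
docs = {
    "first-merge": set(),
    "second-merge": {"first-merge"},
    "FINAL": {"second-merge"},  # stand-in for the unnamed final merge
}

def execution_order(graph):
    # static_order() yields each node only after all of its dependencies.
    return list(TopologicalSorter(graph).static_order())

def run_merge(name, intermediate_dir="./intermediates", lazy=True):
    out_path = os.path.join(intermediate_dir, name)
    if lazy and os.path.isdir(out_path):
        print(f"skipping {name}: {out_path} already exists")
        return
    print(f"merging {name} -> {out_path}")  # the real tool runs a full merge here

for name in execution_order(docs):
    run_merge(name)
```

Running this prints the three steps in dependency order; on a second run, any step whose output directory already exists would be skipped, mirroring the `--lazy` default described above.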
-------------------------------------------------------------------------------- /examples/arcee_fusion.yml: -------------------------------------------------------------------------------- 1 | models: 2 | - model: model_a 3 | - model: model_b 4 | merge_method: arcee_fusion 5 | base_model: model_a 6 | dtype: bfloat16 7 | -------------------------------------------------------------------------------- /examples/bio-merge.yml: -------------------------------------------------------------------------------- 1 | models: 2 | - model: mistralai/Mistral-7B-Instruct-v0.2 3 | parameters: 4 | density: 0.5 5 | weight: 0.5 6 | - model: BioMistral/BioMistral-7B 7 | parameters: 8 | density: 0.5 9 | weight: 0.5 10 | merge_method: ties 11 | base_model: mistralai/Mistral-7B-v0.1 12 | parameters: 13 | normalize: false 14 | int8_mask: true 15 | dtype: float16 16 | -------------------------------------------------------------------------------- /examples/gradient-slerp.yml: -------------------------------------------------------------------------------- 1 | slices: 2 | - sources: 3 | - model: psmathur/orca_mini_v3_13b 4 | layer_range: [0, 40] 5 | - model: garage-bAInd/Platypus2-13B 6 | layer_range: [0, 40] 7 | # or, equivalently, using the `models:` syntax: 8 | # models: 9 | # - model: psmathur/orca_mini_v3_13b 10 | # - model: garage-bAInd/Platypus2-13B 11 | merge_method: slerp 12 | base_model: psmathur/orca_mini_v3_13b 13 | parameters: 14 | t: 15 | - filter: self_attn 16 | value: [0, 0.5, 0.3, 0.7, 1] 17 | - filter: mlp 18 | value: [1, 0.5, 0.7, 0.3, 0] 19 | - value: 0.5 # fallback for rest of tensors 20 | dtype: float16 21 | -------------------------------------------------------------------------------- /examples/linear.yml: -------------------------------------------------------------------------------- 1 | models: 2 | - model: psmathur/orca_mini_v3_13b 3 | parameters: 4 | weight: 1.0 5 | - model: WizardLM/WizardLM-13B-V1.2 6 | parameters: 7 | weight: 0.3 8 | - model: garage-bAInd/Platypus2-13B 9 | parameters: 10 | weight: 0.5 11 | merge_method: linear 12 | dtype: float16 13 | -------------------------------------------------------------------------------- /examples/orcamini-platy-44layer.yml: -------------------------------------------------------------------------------- 1 | slices: 2 | - sources: 3 | - model: psmathur/orca_mini_v3_13b 4 | layer_range: [0, 24] 5 | - sources: 6 | - model: garage-bAInd/Platypus2-13B 7 | layer_range: [20, 40] 8 | merge_method: passthrough 9 | dtype: float16 10 | -------------------------------------------------------------------------------- /examples/ties.yml:
-------------------------------------------------------------------------------- 1 | models: 2 | - model: psmathur/orca_mini_v3_13b 3 | parameters: 4 | density: [1, 0.7, 0.1] # density gradient 5 | weight: 1.0 6 | - model: garage-bAInd/Platypus2-13B 7 | parameters: 8 | density: 0.5 9 | weight: [0, 0.3, 0.7, 1] # weight gradient 10 | - model: WizardLM/WizardMath-13B-V1.0 11 | parameters: 12 | density: 0.33 13 | weight: 14 | - filter: mlp 15 | value: 0.5 16 | - value: 0 17 | merge_method: ties 18 | base_model: TheBloke/Llama-2-13B-fp16 19 | parameters: 20 | normalize: true 21 | int8_mask: true 22 | dtype: float16 23 | -------------------------------------------------------------------------------- /mergekit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/__init__.py -------------------------------------------------------------------------------- /mergekit/_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/_data/__init__.py -------------------------------------------------------------------------------- /mergekit/_data/architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/_data/architectures/__init__.py -------------------------------------------------------------------------------- /mergekit/_data/architectures/baichuan.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "baichuan", 3 | "architectures": [ 4 | "BaichuanForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true 19 | } 20 | ], 21 | "num_layers_config_key": "num_hidden_layers", 22 | "layer_templates": { 23 | "weights": [ 24 | { 25 | "name": "model.layers.${layer_index}.input_layernorm.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.self_attn.W_pack.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 41 | }, 42 | { 43 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 44 | } 45 | ] 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/bert-masked-lm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "bert", 3 | "architectures": [ 4 | "BertForMaskedLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "bert.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "bert.embeddings.token_type_embeddings.weight" 12 | }, 13 | { 14 | "name": "bert.embeddings.word_embeddings.weight", 15 | "is_embed": true 16 | }, 17 | { 18 | "name": "bert.embeddings.LayerNorm.bias", 19 | "aliases": [ 20 | "bert.embeddings.LayerNorm.beta" 21 | ] 22 | }, 23 | { 24 | "name": "bert.embeddings.LayerNorm.weight", 25 | 
"aliases": [ 26 | "bert.embeddings.LayerNorm.gamma" 27 | ] 28 | }, 29 | { 30 | "name": "bert.embeddings.position_ids", 31 | "optional": true, 32 | "force_dtype": "int64" 33 | } 34 | ], 35 | "post_weights": [ 36 | { 37 | "name": "bert.pooler.dense.weight" 38 | }, 39 | { 40 | "name": "bert.pooler.dense.bias" 41 | }, 42 | { 43 | "name": "cls.predictions.bias" 44 | }, 45 | { 46 | "name": "cls.predictions.decoder.weight", 47 | "optional": true, 48 | "tied_names": [ 49 | "bert.embeddings.word_embeddings.weight" 50 | ], 51 | "is_embed": true 52 | } 53 | ], 54 | "num_layers_config_key": "num_hidden_layers", 55 | "layer_templates": { 56 | "weights": [ 57 | { 58 | "name": "bert.encoder.layer.${layer_index}.attention.self.query.weight" 59 | }, 60 | { 61 | "name": "bert.encoder.layer.${layer_index}.attention.self.query.bias" 62 | }, 63 | { 64 | "name": "bert.encoder.layer.${layer_index}.attention.self.key.weight" 65 | }, 66 | { 67 | "name": "bert.encoder.layer.${layer_index}.attention.self.key.bias" 68 | }, 69 | { 70 | "name": "bert.encoder.layer.${layer_index}.attention.self.value.weight" 71 | }, 72 | { 73 | "name": "bert.encoder.layer.${layer_index}.attention.self.value.bias" 74 | }, 75 | { 76 | "name": "bert.encoder.layer.${layer_index}.attention.output.dense.weight" 77 | }, 78 | { 79 | "name": "bert.encoder.layer.${layer_index}.attention.output.dense.bias" 80 | }, 81 | { 82 | "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias", 83 | "aliases": [ 84 | "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.beta" 85 | ] 86 | }, 87 | { 88 | "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight", 89 | "aliases": [ 90 | "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.gamma" 91 | ] 92 | }, 93 | { 94 | "name": "bert.encoder.layer.${layer_index}.intermediate.dense.weight" 95 | }, 96 | { 97 | "name": "bert.encoder.layer.${layer_index}.intermediate.dense.bias" 98 | }, 99 | { 100 | "name": "bert.encoder.layer.${layer_index}.output.dense.weight" 101 | }, 102 | { 103 | "name": "bert.encoder.layer.${layer_index}.output.dense.bias" 104 | }, 105 | { 106 | "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.bias", 107 | "aliases": [ 108 | "bert.encoder.layer.${layer_index}.output.LayerNorm.beta" 109 | ] 110 | }, 111 | { 112 | "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.weight", 113 | "aliases": [ 114 | "bert.encoder.layer.${layer_index}.output.LayerNorm.gamma" 115 | ] 116 | } 117 | ] 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/bert-sequence-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "bert", 3 | "architectures": [ 4 | "BertForSequenceClassification", 5 | "BertForMultipleChoice", 6 | "BertForTokenClassification" 7 | ], 8 | "pre_weights": [ 9 | { 10 | "name": "bert.embeddings.position_embeddings.weight" 11 | }, 12 | { 13 | "name": "bert.embeddings.token_type_embeddings.weight" 14 | }, 15 | { 16 | "name": "bert.embeddings.word_embeddings.weight", 17 | "is_embed": true 18 | }, 19 | { 20 | "name": "bert.embeddings.LayerNorm.bias", 21 | "aliases": [ 22 | "bert.embeddings.LayerNorm.beta" 23 | ] 24 | }, 25 | { 26 | "name": "bert.embeddings.LayerNorm.weight", 27 | "aliases": [ 28 | "bert.embeddings.LayerNorm.gamma" 29 | ] 30 | }, 31 | { 32 | "name": "bert.embeddings.position_ids", 33 | "optional": true, 34 | "force_dtype": "int64" 35 | } 36 | ], 37 | "post_weights": 
[ 38 | { 39 | "name": "bert.pooler.dense.weight", 40 | "optional": true 41 | }, 42 | { 43 | "name": "bert.pooler.dense.bias", 44 | "optional": true 45 | }, 46 | { 47 | "name": "classifier.bias" 48 | }, 49 | { 50 | "name": "classifier.weight" 51 | } 52 | ], 53 | "num_layers_config_key": "num_hidden_layers", 54 | "layer_templates": { 55 | "weights": [ 56 | { 57 | "name": "bert.encoder.layer.${layer_index}.attention.self.query.weight" 58 | }, 59 | { 60 | "name": "bert.encoder.layer.${layer_index}.attention.self.query.bias" 61 | }, 62 | { 63 | "name": "bert.encoder.layer.${layer_index}.attention.self.key.weight" 64 | }, 65 | { 66 | "name": "bert.encoder.layer.${layer_index}.attention.self.key.bias" 67 | }, 68 | { 69 | "name": "bert.encoder.layer.${layer_index}.attention.self.value.weight" 70 | }, 71 | { 72 | "name": "bert.encoder.layer.${layer_index}.attention.self.value.bias" 73 | }, 74 | { 75 | "name": "bert.encoder.layer.${layer_index}.attention.output.dense.weight" 76 | }, 77 | { 78 | "name": "bert.encoder.layer.${layer_index}.attention.output.dense.bias" 79 | }, 80 | { 81 | "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.bias", 82 | "aliases": [ 83 | "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.beta" 84 | ] 85 | }, 86 | { 87 | "name": "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.weight", 88 | "aliases": [ 89 | "bert.encoder.layer.${layer_index}.attention.output.LayerNorm.gamma" 90 | ] 91 | }, 92 | { 93 | "name": "bert.encoder.layer.${layer_index}.intermediate.dense.weight" 94 | }, 95 | { 96 | "name": "bert.encoder.layer.${layer_index}.intermediate.dense.bias" 97 | }, 98 | { 99 | "name": "bert.encoder.layer.${layer_index}.output.dense.weight" 100 | }, 101 | { 102 | "name": "bert.encoder.layer.${layer_index}.output.dense.bias" 103 | }, 104 | { 105 | "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.bias", 106 | "aliases": [ 107 | "bert.encoder.layer.${layer_index}.output.LayerNorm.beta" 108 | ] 109 | }, 110 | { 111 | "name": "bert.encoder.layer.${layer_index}.output.LayerNorm.weight", 112 | "aliases": [ 113 | "bert.encoder.layer.${layer_index}.output.LayerNorm.gamma" 114 | ] 115 | } 116 | ] 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/chatglm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "chatglm", 3 | "architectures": [ 4 | "ChatGLMModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.embedding.word_embeddings.weight", 9 | "is_embed": true 10 | }, 11 | { 12 | "name": "transformer.rotary_pos_emb.inv_freq" 13 | } 14 | ], 15 | "post_weights": [ 16 | { 17 | "name": "transformer.encoder.final_layernorm.weight" 18 | }, 19 | { 20 | "name": "transformer.output_layer.weight", 21 | "is_embed": true 22 | } 23 | ], 24 | "num_layers_config_key": "num_hidden_layers", 25 | "layer_templates": { 26 | "weights": [ 27 | { 28 | "name": "transformer.encoder.layers.${layer_index}.input_layernorm.weight" 29 | }, 30 | { 31 | "name": "transformer.encoder.layers.${layer_index}.mlp.dense_4h_to_h.weight" 32 | }, 33 | { 34 | "name": "transformer.encoder.layers.${layer_index}.mlp.dense_h_to_4h.weight" 35 | }, 36 | { 37 | "name": "transformer.encoder.layers.${layer_index}.post_attention_layernorm.weight" 38 | }, 39 | { 40 | "name": "transformer.encoder.layers.${layer_index}.self_attention.dense.weight" 41 | }, 42 | { 43 | "name": 
"transformer.encoder.layers.${layer_index}.self_attention.query_key_value.bias" 44 | }, 45 | { 46 | "name": "transformer.encoder.layers.${layer_index}.self_attention.query_key_value.weight" 47 | } 48 | ] 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/cohere.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "cohere", 3 | "architectures": [ 4 | "CohereForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true, 19 | "optional": true 20 | } 21 | ], 22 | "num_layers_config_key": "num_hidden_layers", 23 | "layer_templates": { 24 | "weights": [ 25 | { 26 | "name": "model.layers.${layer_index}.input_layernorm.weight" 27 | }, 28 | { 29 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 30 | }, 31 | { 32 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 48 | } 49 | ] 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/distilbert-masked-lm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "distilbert", 3 | "architectures": [ 4 | "DistilBertForMaskedLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "distilbert.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "distilbert.embeddings.word_embeddings.weight", 12 | "is_embed": true 13 | }, 14 | { 15 | "name": "distilbert.embeddings.LayerNorm.bias", 16 | "aliases": [ 17 | "distilbert.embeddings.LayerNorm.beta" 18 | ] 19 | }, 20 | { 21 | "name": "distilbert.embeddings.LayerNorm.weight", 22 | "aliases": [ 23 | "distilbert.embeddings.LayerNorm.gamma" 24 | ] 25 | } 26 | ], 27 | "post_weights": [ 28 | { 29 | "name": "vocab_transform.weight" 30 | }, 31 | { 32 | "name": "vocab_transform.bias" 33 | }, 34 | { 35 | "name": "vocab_layer_norm.bias" 36 | }, 37 | { 38 | "name": "vocab_layer_norm.weight" 39 | }, 40 | { 41 | "name": "vocab_projector.weight", 42 | "is_embed": true, 43 | "optional": true, 44 | "tied_names": [ 45 | "distilbert.embeddings.word_embeddings.weight" 46 | ] 47 | }, 48 | { 49 | "name": "vocab_projector.bias" 50 | } 51 | ], 52 | "num_layers_config_key": "num_hidden_layers", 53 | "layer_templates": { 54 | "weights": [ 55 | { 56 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.weight" 57 | }, 58 | { 59 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.bias" 60 | }, 61 | { 62 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.weight" 63 | }, 64 | { 65 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.bias" 66 | }, 67 | { 68 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.weight" 69 | }, 70 | { 71 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.bias" 72 | }, 73 | { 74 | "name": 
"distilbert.transformer.layer.${layer_index}.attention.out_lin.weight" 75 | }, 76 | { 77 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.bias" 78 | }, 79 | { 80 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.bias" 81 | }, 82 | { 83 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.weight" 84 | }, 85 | { 86 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.weight" 87 | }, 88 | { 89 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.bias" 90 | }, 91 | { 92 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.weight" 93 | }, 94 | { 95 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.bias" 96 | }, 97 | { 98 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.bias" 99 | }, 100 | { 101 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.weight" 102 | } 103 | ] 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/distilbert-sequence-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "distilbert", 3 | "architectures": [ 4 | "DistilBertForSequenceClassification" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "distilbert.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "distilbert.embeddings.word_embeddings.weight", 12 | "is_embed": true 13 | }, 14 | { 15 | "name": "distilbert.embeddings.LayerNorm.bias", 16 | "aliases": [ 17 | "distilbert.embeddings.LayerNorm.beta" 18 | ] 19 | }, 20 | { 21 | "name": "distilbert.embeddings.LayerNorm.weight", 22 | "aliases": [ 23 | "distilbert.embeddings.LayerNorm.gamma" 24 | ] 25 | } 26 | ], 27 | "post_weights": [ 28 | { 29 | "name": "classifier.bias" 30 | }, 31 | { 32 | "name": "classifier.weight" 33 | }, 34 | { 35 | "name": "pre_classifier.bias" 36 | }, 37 | { 38 | "name": "pre_classifier.weight" 39 | } 40 | ], 41 | "num_layers_config_key": "num_hidden_layers", 42 | "layer_templates": { 43 | "weights": [ 44 | { 45 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.weight" 46 | }, 47 | { 48 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.bias" 49 | }, 50 | { 51 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.weight" 52 | }, 53 | { 54 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.bias" 55 | }, 56 | { 57 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.weight" 58 | }, 59 | { 60 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.bias" 61 | }, 62 | { 63 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.weight" 64 | }, 65 | { 66 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.bias" 67 | }, 68 | { 69 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.bias" 70 | }, 71 | { 72 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.weight" 73 | }, 74 | { 75 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.weight" 76 | }, 77 | { 78 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.bias" 79 | }, 80 | { 81 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.weight" 82 | }, 83 | { 84 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.bias" 85 | }, 86 | { 87 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.bias" 88 | }, 89 | { 90 | "name": 
"distilbert.transformer.layer.${layer_index}.output_layer_norm.weight" 91 | } 92 | ] 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/distilbert-token-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "distilbert", 3 | "architectures": [ 4 | "DistilBertForTokenClassification" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "distilbert.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "distilbert.embeddings.word_embeddings.weight", 12 | "is_embed": true 13 | }, 14 | { 15 | "name": "distilbert.embeddings.LayerNorm.bias", 16 | "aliases": [ 17 | "distilbert.embeddings.LayerNorm.beta" 18 | ] 19 | }, 20 | { 21 | "name": "distilbert.embeddings.LayerNorm.weight", 22 | "aliases": [ 23 | "distilbert.embeddings.LayerNorm.gamma" 24 | ] 25 | } 26 | ], 27 | "post_weights": [ 28 | { 29 | "name": "classifier.bias" 30 | }, 31 | { 32 | "name": "classifier.weight" 33 | } 34 | ], 35 | "num_layers_config_key": "num_hidden_layers", 36 | "layer_templates": { 37 | "weights": [ 38 | { 39 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.weight" 40 | }, 41 | { 42 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.bias" 43 | }, 44 | { 45 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.weight" 46 | }, 47 | { 48 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.bias" 49 | }, 50 | { 51 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.weight" 52 | }, 53 | { 54 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.bias" 55 | }, 56 | { 57 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.weight" 58 | }, 59 | { 60 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.bias" 61 | }, 62 | { 63 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.bias" 64 | }, 65 | { 66 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.weight" 67 | }, 68 | { 69 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.weight" 70 | }, 71 | { 72 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.bias" 73 | }, 74 | { 75 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.weight" 76 | }, 77 | { 78 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.bias" 79 | }, 80 | { 81 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.bias" 82 | }, 83 | { 84 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.weight" 85 | } 86 | ] 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/distilbert.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "distilbert", 3 | "architectures": [ 4 | "DistilBertModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "distilbert.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "distilbert.embeddings.word_embeddings.weight", 12 | "is_embed": true 13 | }, 14 | { 15 | "name": "distilbert.embeddings.LayerNorm.bias", 16 | "aliases": [ 17 | "distilbert.embeddings.LayerNorm.beta" 18 | ] 19 | }, 20 | { 21 | "name": "distilbert.embeddings.LayerNorm.weight", 22 | "aliases": [ 23 | "distilbert.embeddings.LayerNorm.gamma" 24 | ] 25 | } 26 | ], 27 | "post_weights": [], 28 | "num_layers_config_key": "num_hidden_layers", 29 | "layer_templates": 
{ 30 | "weights": [ 31 | { 32 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.weight" 33 | }, 34 | { 35 | "name": "distilbert.transformer.layer.${layer_index}.attention.k_lin.bias" 36 | }, 37 | { 38 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.weight" 39 | }, 40 | { 41 | "name": "distilbert.transformer.layer.${layer_index}.attention.q_lin.bias" 42 | }, 43 | { 44 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.weight" 45 | }, 46 | { 47 | "name": "distilbert.transformer.layer.${layer_index}.attention.v_lin.bias" 48 | }, 49 | { 50 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.weight" 51 | }, 52 | { 53 | "name": "distilbert.transformer.layer.${layer_index}.attention.out_lin.bias" 54 | }, 55 | { 56 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.bias" 57 | }, 58 | { 59 | "name": "distilbert.transformer.layer.${layer_index}.sa_layer_norm.weight" 60 | }, 61 | { 62 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.weight" 63 | }, 64 | { 65 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin1.bias" 66 | }, 67 | { 68 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.weight" 69 | }, 70 | { 71 | "name": "distilbert.transformer.layer.${layer_index}.ffn.lin2.bias" 72 | }, 73 | { 74 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.bias" 75 | }, 76 | { 77 | "name": "distilbert.transformer.layer.${layer_index}.output_layer_norm.weight" 78 | } 79 | ] 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/exaone.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "exaone", 3 | "architectures": [ 4 | "ExaoneForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight", 9 | "is_embed": true, 10 | "output_space": "running_residual" 11 | } 12 | ], 13 | "num_layers_config_key": "num_hidden_layers", 14 | "layer_templates": { 15 | "weights": [ 16 | { 17 | "name": "transformer.h.${layer_index}.ln_1.weight", 18 | "input_space": "running_residual" 19 | }, 20 | { 21 | "name": "transformer.h.${layer_index}.attn.attention.q_proj.weight", 22 | "input_space": "running_residual", 23 | "output_space": "attn_qk_${layer_index}", 24 | "head_split": "output", 25 | "is_kq": true 26 | }, 27 | { 28 | "name": "transformer.h.${layer_index}.attn.attention.k_proj.weight", 29 | "input_space": "running_residual", 30 | "output_space": "attn_qk_${layer_index}", 31 | "head_split": "output", 32 | "is_kq": true 33 | }, 34 | { 35 | "name": "transformer.h.${layer_index}.attn.attention.v_proj.weight", 36 | "input_space": "running_residual", 37 | "output_space": "attn_v_${layer_index}", 38 | "head_split": "output" 39 | }, 40 | { 41 | "name": "transformer.h.${layer_index}.attn.attention.out_proj.weight", 42 | "input_space": "attn_v_${layer_index}", 43 | "output_space": "running_residual", 44 | "head_split": "input" 45 | }, 46 | { 47 | "name": "transformer.h.${layer_index}.ln_2.weight", 48 | "input_space": "running_residual" 49 | }, 50 | { 51 | "name": "transformer.h.${layer_index}.mlp.c_fc_0.weight", 52 | "input_space": "running_residual", 53 | "output_space": "up_${layer_index}" 54 | }, 55 | { 56 | "name": "transformer.h.${layer_index}.mlp.c_fc_1.weight", 57 | "input_space": "running_residual", 58 | "output_space": "up_${layer_index}" 59 | }, 60 | { 61 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight", 62 
| "input_space": "up_${layer_index}", 63 | "output_space": "running_residual" 64 | } 65 | ] 66 | }, 67 | "post_weights": [ 68 | { 69 | "name": "transformer.ln_f.weight", 70 | "input_space": "running_residual" 71 | }, 72 | { 73 | "name": "lm_head.weight", 74 | "input_space": "running_residual", 75 | "is_embed": true 76 | } 77 | ] 78 | } 79 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/falcon.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "falcon", 3 | "architectures": [ 4 | "FalconForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.word_embeddings.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "transformer.ln_f.weight" 15 | }, 16 | { 17 | "name": "transformer.ln_f.bias" 18 | }, 19 | { 20 | "name": "lm_head.weight", 21 | "is_embed": true 22 | } 23 | ], 24 | "num_layers_config_key": "num_hidden_layers", 25 | "layer_templates": { 26 | "weights": [ 27 | { 28 | "name": "transformer.h.${layer_index}.ln_attn.bias" 29 | }, 30 | { 31 | "name": "transformer.h.${layer_index}.ln_attn.weight" 32 | }, 33 | { 34 | "name": "transformer.h.${layer_index}.ln_mlp.bias" 35 | }, 36 | { 37 | "name": "transformer.h.${layer_index}.ln_mlp.weight" 38 | }, 39 | { 40 | "name": "transformer.h.${layer_index}.mlp.dense_4h_to_h.weight" 41 | }, 42 | { 43 | "name": "transformer.h.${layer_index}.mlp.dense_h_to_4h.weight" 44 | }, 45 | { 46 | "name": "transformer.h.${layer_index}.self_attention.dense.weight" 47 | }, 48 | { 49 | "name": "transformer.h.${layer_index}.self_attention.query_key_value.weight" 50 | } 51 | ] 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gemma.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gemma", 3 | "architectures": [ 4 | "GemmaForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true, 10 | "output_space": "h_0" 11 | } 12 | ], 13 | "num_layers_config_key": "num_hidden_layers", 14 | "layer_templates": { 15 | "weights": [ 16 | { 17 | "name": "model.layers.${layer_index}.input_layernorm.weight", 18 | "input_space": "h_${layer_index}" 19 | }, 20 | { 21 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight", 22 | "input_space": "h_${layer_index}", 23 | "output_space": "attn_qk_${layer_index}" 24 | }, 25 | { 26 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight", 27 | "input_space": "h_${layer_index}", 28 | "output_space": "attn_qk_${layer_index}" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight", 32 | "input_space": "h_${layer_index}", 33 | "output_space": "attn_v_${layer_index}" 34 | }, 35 | { 36 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight", 37 | "input_space": "attn_v_${layer_index}", 38 | "output_space": "post_attn_${layer_index}" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight", 42 | "input_space": "h_a_${layer_index}" 43 | }, 44 | { 45 | "name": "model.layers.${layer_index}.mlp.up_proj.weight", 46 | "input_space": "h_a_${layer_index}", 47 | "output_space": "up_${layer_index}" 48 | }, 49 | { 50 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight", 51 | "input_space": "h_a_${layer_index}", 52 | "output_space": "up_${layer_index}" 53 | }, 54 | { 55 | "name": 
"model.layers.${layer_index}.mlp.down_proj.weight", 56 | "input_space": "up_${layer_index}", 57 | "output_space": "post_mlp_${layer_index}" 58 | } 59 | ], 60 | "procedural_spaces": [ 61 | { 62 | "name": "h_a_${layer_index}", 63 | "type": "residual", 64 | "inputs": [ 65 | "h_${layer_index}", 66 | "post_attn_${layer_index}" 67 | ] 68 | }, 69 | { 70 | "name": "h_${layer_index+1}", 71 | "type": "residual", 72 | "inputs": [ 73 | "h_a_${layer_index}", 74 | "post_mlp_${layer_index}" 75 | ] 76 | } 77 | ] 78 | }, 79 | "post_weights": [ 80 | { 81 | "name": "model.norm.weight", 82 | "input_space": "h_${num_layers}" 83 | } 84 | ] 85 | } 86 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gemma2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gemma2", 3 | "architectures": [ 4 | "Gemma2ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "num_hidden_layers", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "model.layers.${layer_index}.input_layernorm.weight" 17 | }, 18 | { 19 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 41 | }, 42 | { 43 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 44 | }, 45 | { 46 | "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" 47 | } 48 | ] 49 | }, 50 | "post_weights": [ 51 | { 52 | "name": "model.norm.weight" 53 | }, 54 | { 55 | "name": "lm_head.weight", 56 | "is_embed": true, 57 | "optional": true, 58 | "tied_names": [ 59 | "model.embed_tokens.weight" 60 | ] 61 | } 62 | ] 63 | } 64 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gemma3.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gemma3_text", 3 | "architectures": [ 4 | "Gemma3ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "num_hidden_layers", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "model.layers.${layer_index}.input_layernorm.weight" 17 | }, 18 | { 19 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.self_attn.q_norm.weight" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.self_attn.k_norm.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight" 41 | }, 
42 | { 43 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 44 | }, 45 | { 46 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 47 | }, 48 | { 49 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 50 | }, 51 | { 52 | "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" 53 | } 54 | ] 55 | }, 56 | "post_weights": [ 57 | { 58 | "name": "model.norm.weight" 59 | }, 60 | { 61 | "name": "lm_head.weight", 62 | "is_embed": true, 63 | "optional": true, 64 | "tied_names": [ 65 | "model.embed_tokens.weight" 66 | ] 67 | } 68 | ] 69 | } 70 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/glm4.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "glm4", 3 | "architectures": [ 4 | "Glm4ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "num_hidden_layers", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "model.layers.${layer_index}.input_layernorm.weight" 17 | }, 18 | { 19 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.mlp.gate_up_proj.weight" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.post_mlp_layernorm.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.post_self_attn_layernorm.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.self_attn.q_proj.bias", 38 | "optional": true 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.k_proj.bias", 45 | "optional": true 46 | }, 47 | { 48 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 49 | }, 50 | { 51 | "name": "model.layers.${layer_index}.self_attn.v_proj.bias", 52 | "optional": true 53 | }, 54 | { 55 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 56 | } 57 | ] 58 | }, 59 | "post_weights": [ 60 | { 61 | "name": "model.norm.weight" 62 | }, 63 | { 64 | "name": "lm_head.weight", 65 | "is_embed": true, 66 | "optional": true, 67 | "tied_names": [ 68 | "model.embed_tokens.weight" 69 | ] 70 | } 71 | ] 72 | } 73 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gpt-neox.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt_neox", 3 | "architectures": [ 4 | "GPTNeoXForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "gpt_neox.embed_in.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "gpt_neox.final_layer_norm.bias" 15 | }, 16 | { 17 | "name": "gpt_neox.final_layer_norm.weight" 18 | }, 19 | { 20 | "name": "embed_out.weight", 21 | "is_embed": true 22 | } 23 | ], 24 | "num_layers_config_key": "num_hidden_layers", 25 | "layer_templates": { 26 | "weights": [ 27 | { 28 | "name": "gpt_neox.layers.${layer_index}.attention.dense.weight" 29 | }, 30 | { 31 | "name": "gpt_neox.layers.${layer_index}.attention.dense.bias" 32 | }, 33 | { 34 | "name": "gpt_neox.layers.${layer_index}.attention.query_key_value.weight" 35 | }, 36 | { 37 | "name": "gpt_neox.layers.${layer_index}.attention.query_key_value.bias" 38 | }, 39 | { 40 | 
"name": "gpt_neox.layers.${layer_index}.input_layernorm.weight" 41 | }, 42 | { 43 | "name": "gpt_neox.layers.${layer_index}.input_layernorm.bias" 44 | }, 45 | { 46 | "name": "gpt_neox.layers.${layer_index}.mlp.dense_4h_to_h.weight" 47 | }, 48 | { 49 | "name": "gpt_neox.layers.${layer_index}.mlp.dense_4h_to_h.bias" 50 | }, 51 | { 52 | "name": "gpt_neox.layers.${layer_index}.mlp.dense_h_to_4h.weight" 53 | }, 54 | { 55 | "name": "gpt_neox.layers.${layer_index}.mlp.dense_h_to_4h.bias" 56 | }, 57 | { 58 | "name": "gpt_neox.layers.${layer_index}.post_attention_layernorm.weight" 59 | }, 60 | { 61 | "name": "gpt_neox.layers.${layer_index}.post_attention_layernorm.bias" 62 | }, 63 | { 64 | "name": "gpt_neox.layers.${layer_index}.attention.bias" 65 | }, 66 | { 67 | "name": "gpt_neox.layers.${layer_index}.attention.masked_bias" 68 | }, 69 | { 70 | "name": "gpt_neox.layers.${layer_index}.attention.rotary_emb.inv_freq" 71 | } 72 | ] 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gpt2-sequence-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2", 3 | "architectures": [ 4 | "GPT2ForSequenceClassification" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight" 9 | }, 10 | { 11 | "name": "transformer.wpe.weight" 12 | } 13 | ], 14 | "post_weights": [ 15 | { 16 | "name": "transformer.ln_f.weight" 17 | }, 18 | { 19 | "name": "transformer.ln_f.bias" 20 | }, 21 | { 22 | "name": "score.weight" 23 | } 24 | ], 25 | "num_layers_config_key": "n_layer", 26 | "layer_templates": { 27 | "weights": [ 28 | { 29 | "name": "transformer.h.${layer_index}.attn.c_attn.weight" 30 | }, 31 | { 32 | "name": "transformer.h.${layer_index}.attn.c_attn.bias" 33 | }, 34 | { 35 | "name": "transformer.h.${layer_index}.attn.c_proj.weight" 36 | }, 37 | { 38 | "name": "transformer.h.${layer_index}.attn.c_proj.bias" 39 | }, 40 | { 41 | "name": "transformer.h.${layer_index}.ln_1.weight" 42 | }, 43 | { 44 | "name": "transformer.h.${layer_index}.ln_1.bias" 45 | }, 46 | { 47 | "name": "transformer.h.${layer_index}.ln_2.weight" 48 | }, 49 | { 50 | "name": "transformer.h.${layer_index}.ln_2.bias" 51 | }, 52 | { 53 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight" 54 | }, 55 | { 56 | "name": "transformer.h.${layer_index}.mlp.c_proj.bias" 57 | }, 58 | { 59 | "name": "transformer.h.${layer_index}.mlp.c_fc.weight" 60 | }, 61 | { 62 | "name": "transformer.h.${layer_index}.mlp.c_fc.bias" 63 | } 64 | ] 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gpt2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt2", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight", 9 | "is_embed": true, 10 | "aliases": [ 11 | "wte.weight" 12 | ] 13 | }, 14 | { 15 | "name": "transformer.wpe.weight", 16 | "aliases": [ 17 | "wpe.weight" 18 | ] 19 | } 20 | ], 21 | "post_weights": [ 22 | { 23 | "name": "transformer.ln_f.weight", 24 | "aliases": [ 25 | "ln_f.weight" 26 | ] 27 | }, 28 | { 29 | "name": "transformer.ln_f.bias", 30 | "aliases": [ 31 | "ln_f.bias" 32 | ] 33 | }, 34 | { 35 | "name": "lm_head.weight", 36 | "is_embed": true, 37 | "optional": true, 38 | "tied_names": [ 39 | "transformer.wte.weight", 40 | "wte.weight" 41 | ] 42 | } 43 | ], 44 | "num_layers_config_key": "n_layer", 45 | 
"layer_templates": { 46 | "weights": [ 47 | { 48 | "name": "transformer.h.${layer_index}.attn.c_attn.weight", 49 | "aliases": [ 50 | "h.${layer_index}.attn.c_attn.weight" 51 | ] 52 | }, 53 | { 54 | "name": "transformer.h.${layer_index}.attn.c_attn.bias", 55 | "aliases": [ 56 | "h.${layer_index}.attn.c_attn.bias" 57 | ] 58 | }, 59 | { 60 | "name": "transformer.h.${layer_index}.attn.c_proj.weight", 61 | "aliases": [ 62 | "h.${layer_index}.attn.c_proj.weight" 63 | ] 64 | }, 65 | { 66 | "name": "transformer.h.${layer_index}.attn.c_proj.bias", 67 | "aliases": [ 68 | "h.${layer_index}.attn.c_proj.bias" 69 | ] 70 | }, 71 | { 72 | "name": "transformer.h.${layer_index}.ln_1.weight", 73 | "aliases": [ 74 | "h.${layer_index}.ln_1.weight" 75 | ] 76 | }, 77 | { 78 | "name": "transformer.h.${layer_index}.ln_1.bias", 79 | "aliases": [ 80 | "h.${layer_index}.ln_1.bias" 81 | ] 82 | }, 83 | { 84 | "name": "transformer.h.${layer_index}.ln_2.weight", 85 | "aliases": [ 86 | "h.${layer_index}.ln_2.weight" 87 | ] 88 | }, 89 | { 90 | "name": "transformer.h.${layer_index}.ln_2.bias", 91 | "aliases": [ 92 | "h.${layer_index}.ln_2.bias" 93 | ] 94 | }, 95 | { 96 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight", 97 | "aliases": [ 98 | "h.${layer_index}.mlp.c_proj.weight" 99 | ] 100 | }, 101 | { 102 | "name": "transformer.h.${layer_index}.mlp.c_proj.bias", 103 | "aliases": [ 104 | "h.${layer_index}.mlp.c_proj.bias" 105 | ] 106 | }, 107 | { 108 | "name": "transformer.h.${layer_index}.mlp.c_fc.weight", 109 | "aliases": [ 110 | "h.${layer_index}.mlp.c_fc.weight" 111 | ] 112 | }, 113 | { 114 | "name": "transformer.h.${layer_index}.mlp.c_fc.bias", 115 | "aliases": [ 116 | "h.${layer_index}.mlp.c_fc.bias" 117 | ] 118 | } 119 | ] 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/gptbigcode.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "gpt_bigcode", 3 | "architectures": [ 4 | "GPTBigCodeForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight", 9 | "is_embed": true 10 | }, 11 | { 12 | "name": "transformer.wpe.weight" 13 | } 14 | ], 15 | "post_weights": [ 16 | { 17 | "name": "transformer.ln_f.weight" 18 | }, 19 | { 20 | "name": "transformer.ln_f.bias" 21 | }, 22 | { 23 | "name": "lm_head.weight", 24 | "is_embed": true, 25 | "optional": true, 26 | "tied_names": [ 27 | "transformer.wte.weight" 28 | ] 29 | } 30 | ], 31 | "num_layers_config_key": "n_layer", 32 | "layer_templates": { 33 | "weights": [ 34 | { 35 | "name": "transformer.h.${layer_index}.attn.c_attn.weight" 36 | }, 37 | { 38 | "name": "transformer.h.${layer_index}.attn.c_attn.bias" 39 | }, 40 | { 41 | "name": "transformer.h.${layer_index}.attn.c_proj.weight" 42 | }, 43 | { 44 | "name": "transformer.h.${layer_index}.attn.c_proj.bias" 45 | }, 46 | { 47 | "name": "transformer.h.${layer_index}.ln_1.weight" 48 | }, 49 | { 50 | "name": "transformer.h.${layer_index}.ln_1.bias" 51 | }, 52 | { 53 | "name": "transformer.h.${layer_index}.ln_2.weight" 54 | }, 55 | { 56 | "name": "transformer.h.${layer_index}.ln_2.bias" 57 | }, 58 | { 59 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight" 60 | }, 61 | { 62 | "name": "transformer.h.${layer_index}.mlp.c_proj.bias" 63 | }, 64 | { 65 | "name": "transformer.h.${layer_index}.mlp.c_fc.weight" 66 | }, 67 | { 68 | "name": "transformer.h.${layer_index}.mlp.c_fc.bias" 69 | } 70 | ] 71 | } 72 | } 73 | 
-------------------------------------------------------------------------------- /mergekit/_data/architectures/internlm2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "internlm2", 3 | "architectures": [ 4 | "InternLM2ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.tok_embeddings.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "output.weight", 18 | "is_embed": true, 19 | "optional": true, 20 | "tied_names": [ 21 | "model.tok_embeddings.weight" 22 | ] 23 | } 24 | ], 25 | "num_layers_config_key": "num_hidden_layers", 26 | "layer_templates": { 27 | "weights": [ 28 | { 29 | "name": "model.layers.${layer_index}.attention_norm.weight" 30 | }, 31 | { 32 | "name": "model.layers.${layer_index}.ffn_norm.weight" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.attention.wqkv.weight" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.attention.wo.weight" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.feed_forward.w1.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.feed_forward.w2.weight" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.feed_forward.w3.weight" 48 | } 49 | ] 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/jais.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "jais", 3 | "architectures": [ 4 | "JAISLMHeadModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight", 9 | "is_embed": true 10 | }, 11 | { 12 | "name": "transformer.relative_pe.slopes" 13 | } 14 | ], 15 | "post_weights": [ 16 | { 17 | "name": "transformer.ln_f.weight" 18 | }, 19 | { 20 | "name": "transformer.ln_f.bias" 21 | } 22 | ], 23 | "num_layers_config_key": "n_layer", 24 | "layer_templates": { 25 | "weights": [ 26 | { 27 | "name": "transformer.h.${layer_index}.attn.c_attn.weight" 28 | }, 29 | { 30 | "name": "transformer.h.${layer_index}.attn.c_attn.bias" 31 | }, 32 | { 33 | "name": "transformer.h.${layer_index}.attn.c_proj.weight" 34 | }, 35 | { 36 | "name": "transformer.h.${layer_index}.attn.c_proj.bias" 37 | }, 38 | { 39 | "name": "transformer.h.${layer_index}.ln_1.weight" 40 | }, 41 | { 42 | "name": "transformer.h.${layer_index}.ln_1.bias" 43 | }, 44 | { 45 | "name": "transformer.h.${layer_index}.ln_2.weight" 46 | }, 47 | { 48 | "name": "transformer.h.${layer_index}.ln_2.bias" 49 | }, 50 | { 51 | "name": "transformer.h.${layer_index}.mlp.c_fc.weight" 52 | }, 53 | { 54 | "name": "transformer.h.${layer_index}.mlp.c_fc.bias" 55 | }, 56 | { 57 | "name": "transformer.h.${layer_index}.mlp.c_fc2.weight" 58 | }, 59 | { 60 | "name": "transformer.h.${layer_index}.mlp.c_fc2.bias" 61 | }, 62 | { 63 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight" 64 | }, 65 | { 66 | "name": "transformer.h.${layer_index}.mlp.c_proj.bias" 67 | } 68 | ] 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "llama", 3 | "architectures": [ 4 | "LlamaForCausalLM", 5 | "LLaMaForCausalLM" 6 | ], 7 | "pre_weights": [ 8 | { 9 | "name": "model.embed_tokens.weight", 10 | "is_embed": true, 11 | "output_space": "running_residual" 12 | } 13 | ], 14 | "num_layers_config_key": 
"num_hidden_layers", 15 | "layer_templates": { 16 | "weights": [ 17 | { 18 | "name": "model.layers.${layer_index}.input_layernorm.weight", 19 | "input_space": "running_residual" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight", 23 | "input_space": "running_residual", 24 | "output_space": "attn_qk_${layer_index}", 25 | "head_split": "output", 26 | "is_kq": true 27 | }, 28 | { 29 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight", 30 | "input_space": "running_residual", 31 | "output_space": "attn_qk_${layer_index}", 32 | "head_split": "output", 33 | "is_kq": true 34 | }, 35 | { 36 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight", 37 | "input_space": "running_residual", 38 | "output_space": "attn_v_${layer_index}", 39 | "head_split": "output" 40 | }, 41 | { 42 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight", 43 | "input_space": "attn_v_${layer_index}", 44 | "output_space": "running_residual", 45 | "head_split": "input" 46 | }, 47 | { 48 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight", 49 | "input_space": "running_residual" 50 | }, 51 | { 52 | "name": "model.layers.${layer_index}.mlp.up_proj.weight", 53 | "input_space": "running_residual", 54 | "output_space": "up_${layer_index}" 55 | }, 56 | { 57 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight", 58 | "input_space": "running_residual", 59 | "output_space": "up_${layer_index}" 60 | }, 61 | { 62 | "name": "model.layers.${layer_index}.mlp.down_proj.weight", 63 | "input_space": "up_${layer_index}", 64 | "output_space": "running_residual" 65 | } 66 | ] 67 | }, 68 | "post_weights": [ 69 | { 70 | "name": "model.norm.weight", 71 | "input_space": "running_residual" 72 | }, 73 | { 74 | "name": "lm_head.weight", 75 | "input_space": "running_residual", 76 | "is_embed": true, 77 | "optional": true, 78 | "tied_names": [ 79 | "model.embed_tokens.weight" 80 | ] 81 | } 82 | ] 83 | } 84 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/mamba.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "mamba", 3 | "architectures": [ 4 | "MambaForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "backbone.embeddings.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "backbone.norm_f.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true, 19 | "optional": true, 20 | "tied_names": [ 21 | "backbone.embeddings.weight" 22 | ] 23 | } 24 | ], 25 | "num_layers_config_key": "num_hidden_layers", 26 | "layer_templates": { 27 | "weights": [ 28 | { 29 | "name": "backbone.layers.${layer_index}.mixer.A_log" 30 | }, 31 | { 32 | "name": "backbone.layers.${layer_index}.mixer.conv1d.bias" 33 | }, 34 | { 35 | "name": "backbone.layers.${layer_index}.mixer.conv1d.weight" 36 | }, 37 | { 38 | "name": "backbone.layers.${layer_index}.mixer.D" 39 | }, 40 | { 41 | "name": "backbone.layers.${layer_index}.mixer.dt_proj.bias" 42 | }, 43 | { 44 | "name": "backbone.layers.${layer_index}.mixer.dt_proj.weight" 45 | }, 46 | { 47 | "name": "backbone.layers.${layer_index}.mixer.in_proj.weight" 48 | }, 49 | { 50 | "name": "backbone.layers.${layer_index}.mixer.out_proj.weight" 51 | }, 52 | { 53 | "name": "backbone.layers.${layer_index}.mixer.x_proj.weight" 54 | }, 55 | { 56 | "name": "backbone.layers.${layer_index}.norm.weight" 57 | } 58 | ] 59 | } 60 | } 61 | 
-------------------------------------------------------------------------------- /mergekit/_data/architectures/mistral.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "mistral", 3 | "architectures": [ 4 | "MistralForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "num_hidden_layers", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "model.layers.${layer_index}.input_layernorm.weight" 17 | }, 18 | { 19 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 41 | } 42 | ] 43 | }, 44 | "post_weights": [ 45 | { 46 | "name": "model.norm.weight" 47 | }, 48 | { 49 | "name": "lm_head.weight", 50 | "is_embed": true, 51 | "optional": true, 52 | "tied_names": [ 53 | "model.embed_tokens.weight" 54 | ] 55 | } 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/olmo2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "olmo2", 3 | "architectures": [ 4 | "Olmo2ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "num_hidden_layers", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 17 | }, 18 | { 19 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 20 | }, 21 | { 22 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" 41 | } 42 | ] 43 | }, 44 | "post_weights": [ 45 | { 46 | "name": "model.norm.weight" 47 | }, 48 | { 49 | "name": "lm_head.weight", 50 | "is_embed": true, 51 | "optional": true, 52 | "tied_names": [ 53 | "model.embed_tokens.weight" 54 | ] 55 | } 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/phi-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "mixformer-sequential", 3 | "architectures": [ 4 | "MixFormerSequentialForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "layers.0.wte.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "num_layers_config_key": "n_layer", 13 | "layer_templates": { 14 | "weights": [ 15 | { 16 | "name": "layers.${layer_index}.ln.bias" 17 | }, 18 | { 19 | 
"name": "layers.${layer_index}.ln.weight" 20 | }, 21 | { 22 | "name": "layers.${layer_index}.mixer.Wqkv.bias" 23 | }, 24 | { 25 | "name": "layers.${layer_index}.mixer.Wqkv.weight" 26 | }, 27 | { 28 | "name": "layers.${layer_index}.mixer.out_proj.bias" 29 | }, 30 | { 31 | "name": "layers.${layer_index}.mixer.out_proj.weight" 32 | }, 33 | { 34 | "name": "layers.${layer_index}.mixer.rotary_emb.inv_freq" 35 | }, 36 | { 37 | "name": "layers.${layer_index}.mlp.fc1.bias" 38 | }, 39 | { 40 | "name": "layers.${layer_index}.mlp.fc1.weight" 41 | }, 42 | { 43 | "name": "layers.${layer_index}.mlp.fc2.bias" 44 | }, 45 | { 46 | "name": "layers.${layer_index}.mlp.fc2.weight" 47 | } 48 | ] 49 | }, 50 | "post_weights": [ 51 | { 52 | "name": "layers.${num_layers}.linear.bias", 53 | "is_embed": true 54 | }, 55 | { 56 | "name": "layers.${num_layers}.linear.weight", 57 | "is_embed": true 58 | }, 59 | { 60 | "name": "layers.${num_layers}.ln.bias" 61 | }, 62 | { 63 | "name": "layers.${num_layers}.ln.weight" 64 | } 65 | ] 66 | } 67 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/phi2-old.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "phi-msft", 3 | "architectures": [ 4 | "PhiForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.embd.wte.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "lm_head.linear.bias" 15 | }, 16 | { 17 | "name": "lm_head.linear.weight", 18 | "is_embed": true 19 | }, 20 | { 21 | "name": "lm_head.ln.bias" 22 | }, 23 | { 24 | "name": "lm_head.ln.weight" 25 | } 26 | ], 27 | "num_layers_config_key": "n_layer", 28 | "layer_templates": { 29 | "weights": [ 30 | { 31 | "name": "transformer.h.${layer_index}.ln.bias" 32 | }, 33 | { 34 | "name": "transformer.h.${layer_index}.ln.weight" 35 | }, 36 | { 37 | "name": "transformer.h.${layer_index}.mixer.out_proj.bias" 38 | }, 39 | { 40 | "name": "transformer.h.${layer_index}.mixer.out_proj.weight" 41 | }, 42 | { 43 | "name": "transformer.h.${layer_index}.mixer.Wqkv.bias" 44 | }, 45 | { 46 | "name": "transformer.h.${layer_index}.mixer.Wqkv.weight" 47 | }, 48 | { 49 | "name": "transformer.h.${layer_index}.mlp.fc1.bias" 50 | }, 51 | { 52 | "name": "transformer.h.${layer_index}.mlp.fc1.weight" 53 | }, 54 | { 55 | "name": "transformer.h.${layer_index}.mlp.fc2.bias" 56 | }, 57 | { 58 | "name": "transformer.h.${layer_index}.mlp.fc2.weight" 59 | } 60 | ] 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/phi2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "phi", 3 | "architectures": [ 4 | "PhiForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "lm_head.bias" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true 19 | }, 20 | { 21 | "name": "model.final_layernorm.bias" 22 | }, 23 | { 24 | "name": "model.final_layernorm.weight" 25 | } 26 | ], 27 | "num_layers_config_key": "num_hidden_layers", 28 | "layer_templates": { 29 | "weights": [ 30 | { 31 | "name": "model.layers.${layer_index}.input_layernorm.bias" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.input_layernorm.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.self_attn.dense.bias" 38 | }, 39 | { 40 | "name": 
"model.layers.${layer_index}.self_attn.dense.weight" 41 | }, 42 | { 43 | "name": "model.layers.${layer_index}.self_attn.q_proj.bias" 44 | }, 45 | { 46 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 47 | }, 48 | { 49 | "name": "model.layers.${layer_index}.self_attn.k_proj.bias" 50 | }, 51 | { 52 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 53 | }, 54 | { 55 | "name": "model.layers.${layer_index}.self_attn.v_proj.bias" 56 | }, 57 | { 58 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 59 | }, 60 | { 61 | "name": "model.layers.${layer_index}.mlp.fc1.bias" 62 | }, 63 | { 64 | "name": "model.layers.${layer_index}.mlp.fc1.weight" 65 | }, 66 | { 67 | "name": "model.layers.${layer_index}.mlp.fc2.bias" 68 | }, 69 | { 70 | "name": "model.layers.${layer_index}.mlp.fc2.weight" 71 | } 72 | ] 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/phi3-small.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "phi3small", 3 | "architectures": [ 4 | "Phi3SmallForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "lm_head.weight", 15 | "is_embed": true, 16 | "optional": true, 17 | "tied_names": [ 18 | "model.embed_tokens.weight" 19 | ] 20 | }, 21 | { 22 | "name": "model.final_layernorm.weight" 23 | }, 24 | { 25 | "name": "model.final_layernorm.bias" 26 | } 27 | ], 28 | "num_layers_config_key": "num_hidden_layers", 29 | "layer_templates": { 30 | "weights": [ 31 | { 32 | "name": "model.layers.${layer_index}.input_layernorm.weight" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.input_layernorm.bias" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.post_attention_layernorm.bias" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.dense.weight" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.self_attn.dense.bias" 48 | }, 49 | { 50 | "name": "model.layers.${layer_index}.self_attn.query_key_value.weight" 51 | }, 52 | { 53 | "name": "model.layers.${layer_index}.self_attn.query_key_value.bias" 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 57 | }, 58 | { 59 | "name": "model.layers.${layer_index}.mlp.up_proj.bias" 60 | }, 61 | { 62 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 63 | }, 64 | { 65 | "name": "model.layers.${layer_index}.mlp.down_proj.bias" 66 | } 67 | ] 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/phi3.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "phi3", 3 | "architectures": [ 4 | "Phi3ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "lm_head.weight", 15 | "is_embed": true 16 | }, 17 | { 18 | "name": "model.norm.weight" 19 | } 20 | ], 21 | "num_layers_config_key": "num_hidden_layers", 22 | "layer_templates": { 23 | "weights": [ 24 | { 25 | "name": "model.layers.${layer_index}.input_layernorm.weight" 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 32 | 
}, 33 | { 34 | "name": "model.layers.${layer_index}.self_attn.qkv_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.gate_up_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 41 | } 42 | ] 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/qwen.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "qwen", 3 | "architectures": [ 4 | "QWenLMHeadModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "transformer.wte.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "transformer.ln_f.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true 19 | } 20 | ], 21 | "num_layers_config_key": "num_hidden_layers", 22 | "layer_templates": { 23 | "weights": [ 24 | { 25 | "name": "transformer.h.${layer_index}.attn.c_attn.bias" 26 | }, 27 | { 28 | "name": "transformer.h.${layer_index}.attn.c_attn.weight" 29 | }, 30 | { 31 | "name": "transformer.h.${layer_index}.attn.c_proj.weight" 32 | }, 33 | { 34 | "name": "transformer.h.${layer_index}.ln_1.weight" 35 | }, 36 | { 37 | "name": "transformer.h.${layer_index}.ln_2.weight" 38 | }, 39 | { 40 | "name": "transformer.h.${layer_index}.mlp.c_proj.weight" 41 | }, 42 | { 43 | "name": "transformer.h.${layer_index}.mlp.w1.weight" 44 | }, 45 | { 46 | "name": "transformer.h.${layer_index}.mlp.w2.weight" 47 | } 48 | ] 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/qwen2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "qwen2", 3 | "architectures": [ 4 | "Qwen2ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true, 19 | "optional": true, 20 | "tied_names": [ 21 | "model.embed_tokens.weight" 22 | ] 23 | } 24 | ], 25 | "num_layers_config_key": "num_hidden_layers", 26 | "layer_templates": { 27 | "weights": [ 28 | { 29 | "name": "model.layers.${layer_index}.input_layernorm.weight" 30 | }, 31 | { 32 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.k_proj.bias" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 48 | }, 49 | { 50 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 51 | }, 52 | { 53 | "name": "model.layers.${layer_index}.self_attn.q_proj.bias" 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 57 | }, 58 | { 59 | "name": "model.layers.${layer_index}.self_attn.v_proj.bias" 60 | }, 61 | { 62 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 63 | } 64 | ] 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/qwen3.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "qwen3", 3 | "architectures": [ 4 | "Qwen3ForCausalLM" 5 | ], 6 | 
"pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "lm_head.weight", 18 | "is_embed": true, 19 | "optional": true, 20 | "tied_names": [ 21 | "model.embed_tokens.weight" 22 | ] 23 | } 24 | ], 25 | "num_layers_config_key": "num_hidden_layers", 26 | "layer_templates": { 27 | "weights": [ 28 | { 29 | "name": "model.layers.${layer_index}.input_layernorm.weight" 30 | }, 31 | { 32 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.k_norm.weight" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 48 | }, 49 | { 50 | "name": "model.layers.${layer_index}.self_attn.q_norm.weight" 51 | }, 52 | { 53 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 57 | }, 58 | { 59 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 60 | } 61 | ] 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/roberta-masked-lm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "roberta", 3 | "architectures": [ 4 | "RobertaForMaskedLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "roberta.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "roberta.embeddings.word_embeddings.weight", 12 | "is_embed": true 13 | }, 14 | { 15 | "name": "roberta.embeddings.token_type_embeddings.weight" 16 | }, 17 | { 18 | "name": "roberta.embeddings.LayerNorm.weight" 19 | }, 20 | { 21 | "name": "roberta.embeddings.LayerNorm.bias" 22 | }, 23 | { 24 | "name": "roberta.embeddings.position_ids", 25 | "optional": true, 26 | "force_dtype": "int64" 27 | } 28 | ], 29 | "post_weights": [ 30 | { 31 | "name": "lm_head.bias" 32 | }, 33 | { 34 | "name": "lm_head.dense.weight" 35 | }, 36 | { 37 | "name": "lm_head.dense.bias" 38 | }, 39 | { 40 | "name": "lm_head.layer_norm.weight" 41 | }, 42 | { 43 | "name": "lm_head.layer_norm.bias" 44 | }, 45 | { 46 | "name": "lm_head.decoder.weight", 47 | "is_embed": true, 48 | "optional": true, 49 | "tied_names": [ 50 | "roberta.embeddings.word_embeddings.weight" 51 | ] 52 | } 53 | ], 54 | "num_layers_config_key": "num_hidden_layers", 55 | "layer_templates": { 56 | "weights": [ 57 | { 58 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.weight" 59 | }, 60 | { 61 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.bias" 62 | }, 63 | { 64 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.weight" 65 | }, 66 | { 67 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.bias" 68 | }, 69 | { 70 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.weight" 71 | }, 72 | { 73 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.bias" 74 | }, 75 | { 76 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.weight" 77 | }, 78 | { 79 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.bias" 80 | }, 81 | { 82 | "name": 
"roberta.encoder.layer.${layer_index}.attention.self.value.weight" 83 | }, 84 | { 85 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.bias" 86 | }, 87 | { 88 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.weight" 89 | }, 90 | { 91 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.bias" 92 | }, 93 | { 94 | "name": "roberta.encoder.layer.${layer_index}.output.dense.weight" 95 | }, 96 | { 97 | "name": "roberta.encoder.layer.${layer_index}.output.dense.bias" 98 | }, 99 | { 100 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.weight" 101 | }, 102 | { 103 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.bias" 104 | } 105 | ] 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/roberta-sequence-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "roberta", 3 | "architectures": [ 4 | "RobertaForSequenceClassification" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "roberta.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "roberta.embeddings.word_embeddings.weight" 12 | }, 13 | { 14 | "name": "roberta.embeddings.token_type_embeddings.weight" 15 | }, 16 | { 17 | "name": "roberta.embeddings.LayerNorm.weight" 18 | }, 19 | { 20 | "name": "roberta.embeddings.LayerNorm.bias" 21 | }, 22 | { 23 | "name": "roberta.embeddings.position_ids", 24 | "optional": true, 25 | "force_dtype": "int64" 26 | } 27 | ], 28 | "post_weights": [ 29 | { 30 | "name": "classifier.dense.weight" 31 | }, 32 | { 33 | "name": "classifier.dense.bias" 34 | }, 35 | { 36 | "name": "classifier.out_proj.weight" 37 | }, 38 | { 39 | "name": "classifier.out_proj.bias" 40 | } 41 | ], 42 | "num_layers_config_key": "num_hidden_layers", 43 | "layer_templates": { 44 | "weights": [ 45 | { 46 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.weight" 47 | }, 48 | { 49 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.bias" 50 | }, 51 | { 52 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.weight" 53 | }, 54 | { 55 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.bias" 56 | }, 57 | { 58 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.weight" 59 | }, 60 | { 61 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.bias" 62 | }, 63 | { 64 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.weight" 65 | }, 66 | { 67 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.bias" 68 | }, 69 | { 70 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.weight" 71 | }, 72 | { 73 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.bias" 74 | }, 75 | { 76 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.weight" 77 | }, 78 | { 79 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.bias" 80 | }, 81 | { 82 | "name": "roberta.encoder.layer.${layer_index}.output.dense.weight" 83 | }, 84 | { 85 | "name": "roberta.encoder.layer.${layer_index}.output.dense.bias" 86 | }, 87 | { 88 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.weight" 89 | }, 90 | { 91 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.bias" 92 | } 93 | ] 94 | } 95 | } 96 | -------------------------------------------------------------------------------- 
/mergekit/_data/architectures/roberta-token-classification.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "roberta", 3 | "architectures": [ 4 | "RobertaForTokenClassification" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "roberta.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "roberta.embeddings.word_embeddings.weight" 12 | }, 13 | { 14 | "name": "roberta.embeddings.token_type_embeddings.weight" 15 | }, 16 | { 17 | "name": "roberta.embeddings.LayerNorm.weight" 18 | }, 19 | { 20 | "name": "roberta.embeddings.LayerNorm.bias" 21 | }, 22 | { 23 | "name": "roberta.embeddings.position_ids", 24 | "optional": true, 25 | "force_dtype": "int64" 26 | } 27 | ], 28 | "post_weights": [ 29 | { 30 | "name": "classifier.weight" 31 | }, 32 | { 33 | "name": "classifier.bias" 34 | } 35 | ], 36 | "num_layers_config_key": "num_hidden_layers", 37 | "layer_templates": { 38 | "weights": [ 39 | { 40 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.weight" 41 | }, 42 | { 43 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.bias" 44 | }, 45 | { 46 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.weight" 47 | }, 48 | { 49 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.bias" 50 | }, 51 | { 52 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.weight" 53 | }, 54 | { 55 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.bias" 56 | }, 57 | { 58 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.weight" 59 | }, 60 | { 61 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.bias" 62 | }, 63 | { 64 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.weight" 65 | }, 66 | { 67 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.bias" 68 | }, 69 | { 70 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.weight" 71 | }, 72 | { 73 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.bias" 74 | }, 75 | { 76 | "name": "roberta.encoder.layer.${layer_index}.output.dense.weight" 77 | }, 78 | { 79 | "name": "roberta.encoder.layer.${layer_index}.output.dense.bias" 80 | }, 81 | { 82 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.weight" 83 | }, 84 | { 85 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.bias" 86 | } 87 | ] 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/roberta.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "roberta", 3 | "architectures": [ 4 | "RobertaModel" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "roberta.embeddings.position_embeddings.weight" 9 | }, 10 | { 11 | "name": "roberta.embeddings.word_embeddings.weight" 12 | }, 13 | { 14 | "name": "roberta.embeddings.token_type_embeddings.weight" 15 | }, 16 | { 17 | "name": "roberta.embeddings.LayerNorm.weight" 18 | }, 19 | { 20 | "name": "roberta.embeddings.LayerNorm.bias" 21 | }, 22 | { 23 | "name": "roberta.embeddings.position_ids", 24 | "optional": true, 25 | "force_dtype": "int64" 26 | } 27 | ], 28 | "post_weights": [ 29 | { 30 | "name": "pooler.dense.weight" 31 | }, 32 | { 33 | "name": "pooler.dense.bias" 34 | } 35 | ], 36 | "num_layers_config_key": "num_hidden_layers", 37 | "layer_templates": { 38 | "weights": [ 39 | { 40 | "name": 
"roberta.encoder.layer.${layer_index}.attention.output.dense.weight" 41 | }, 42 | { 43 | "name": "roberta.encoder.layer.${layer_index}.attention.output.dense.bias" 44 | }, 45 | { 46 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.weight" 47 | }, 48 | { 49 | "name": "roberta.encoder.layer.${layer_index}.attention.output.LayerNorm.bias" 50 | }, 51 | { 52 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.weight" 53 | }, 54 | { 55 | "name": "roberta.encoder.layer.${layer_index}.attention.self.query.bias" 56 | }, 57 | { 58 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.weight" 59 | }, 60 | { 61 | "name": "roberta.encoder.layer.${layer_index}.attention.self.key.bias" 62 | }, 63 | { 64 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.weight" 65 | }, 66 | { 67 | "name": "roberta.encoder.layer.${layer_index}.attention.self.value.bias" 68 | }, 69 | { 70 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.weight" 71 | }, 72 | { 73 | "name": "roberta.encoder.layer.${layer_index}.intermediate.dense.bias" 74 | }, 75 | { 76 | "name": "roberta.encoder.layer.${layer_index}.output.dense.weight" 77 | }, 78 | { 79 | "name": "roberta.encoder.layer.${layer_index}.output.dense.bias" 80 | }, 81 | { 82 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.weight" 83 | }, 84 | { 85 | "name": "roberta.encoder.layer.${layer_index}.output.LayerNorm.bias" 86 | } 87 | ] 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/solar.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "solar", 3 | "architectures": [ 4 | "SolarForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true, 10 | "output_space": "running_residual" 11 | } 12 | ], 13 | "num_layers_config_key": "num_hidden_layers", 14 | "layer_templates": { 15 | "weights": [ 16 | { 17 | "name": "model.layers.${layer_index}.input_layernorm.weight", 18 | "input_space": "running_residual" 19 | }, 20 | { 21 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight", 22 | "input_space": "running_residual", 23 | "output_space": "attn_qk_${layer_index}", 24 | "head_split": "output", 25 | "is_kq": true 26 | }, 27 | { 28 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight", 29 | "input_space": "running_residual", 30 | "output_space": "attn_qk_${layer_index}", 31 | "head_split": "output", 32 | "is_kq": true 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight", 36 | "input_space": "running_residual", 37 | "output_space": "attn_v_${layer_index}", 38 | "head_split": "output" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight", 42 | "input_space": "attn_v_${layer_index}", 43 | "output_space": "running_residual", 44 | "head_split": "input" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight", 48 | "input_space": "running_residual" 49 | }, 50 | { 51 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight", 52 | "input_space": "running_residual", 53 | "output_space": "up_${layer_index}" 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.mlp.up_proj.weight", 57 | "input_space": "running_residual", 58 | "output_space": "up_${layer_index}" 59 | }, 60 | { 61 | "name": "model.layers.${layer_index}.mlp.down_proj.weight", 62 | "input_space": "up_${layer_index}", 63 | 
"output_space": "running_residual" 64 | } 65 | ] 66 | }, 67 | "post_weights": [ 68 | { 69 | "name": "model.norm.weight", 70 | "input_space": "running_residual" 71 | }, 72 | { 73 | "name": "lm_head.weight", 74 | "input_space": "running_residual", 75 | "is_embed": true, 76 | "optional": true, 77 | "tied_names": [ 78 | "model.lm_head.weight" 79 | ] 80 | } 81 | ] 82 | } 83 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/stablelm.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "stablelm_epoch", 3 | "architectures": [ 4 | "StableLMEpochForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true, 10 | "output_space": "h_0" 11 | } 12 | ], 13 | "num_layers_config_key": "num_hidden_layers", 14 | "layer_templates": { 15 | "weights": [ 16 | { 17 | "name": "model.layers.${layer_index}.input_layernorm.weight", 18 | "input_space": "h_${layer_index}" 19 | }, 20 | { 21 | "name": "model.layers.${layer_index}.input_layernorm.bias", 22 | "input_space": "h_${layer_index}" 23 | }, 24 | { 25 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight", 26 | "input_space": "h_${layer_index}", 27 | "output_space": "attn_qk_${layer_index}" 28 | }, 29 | { 30 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight", 31 | "input_space": "h_${layer_index}", 32 | "output_space": "attn_qk_${layer_index}" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight", 36 | "input_space": "h_${layer_index}", 37 | "output_space": "attn_v_${layer_index}" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight", 41 | "input_space": "attn_v_${layer_index}", 42 | "output_space": "post_attn_${layer_index}" 43 | }, 44 | { 45 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight", 46 | "input_space": "h_a_${layer_index}" 47 | }, 48 | { 49 | "name": "model.layers.${layer_index}.post_attention_layernorm.bias", 50 | "input_space": "h_a_${layer_index}" 51 | }, 52 | { 53 | "name": "model.layers.${layer_index}.mlp.up_proj.weight", 54 | "input_space": "h_a_${layer_index}", 55 | "output_space": "up_${layer_index}" 56 | }, 57 | { 58 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight", 59 | "input_space": "h_a_${layer_index}", 60 | "output_space": "up_${layer_index}" 61 | }, 62 | { 63 | "name": "model.layers.${layer_index}.mlp.down_proj.weight", 64 | "input_space": "up_${layer_index}", 65 | "output_space": "post_mlp_${layer_index}" 66 | } 67 | ], 68 | "procedural_spaces": [ 69 | { 70 | "name": "h_a_${layer_index}", 71 | "type": "residual", 72 | "inputs": [ 73 | "h_${layer_index}", 74 | "post_attn_${layer_index}" 75 | ] 76 | }, 77 | { 78 | "name": "h_${layer_index+1}", 79 | "type": "residual", 80 | "inputs": [ 81 | "h_a_${layer_index}", 82 | "post_mlp_${layer_index}" 83 | ] 84 | } 85 | ] 86 | }, 87 | "post_weights": [ 88 | { 89 | "name": "model.norm.weight", 90 | "input_space": "h_${num_layers}" 91 | }, 92 | { 93 | "name": "lm_head.weight", 94 | "input_space": "h_${num_layers}", 95 | "is_embed": true 96 | } 97 | ] 98 | } 99 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/stablelm2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "stablelm", 3 | "architectures": [ 4 | "StableLmForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 
| "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "model.norm.weight" 15 | }, 16 | { 17 | "name": "model.norm.bias" 18 | }, 19 | { 20 | "name": "lm_head.weight", 21 | "is_embed": true 22 | } 23 | ], 24 | "num_layers_config_key": "num_hidden_layers", 25 | "layer_templates": { 26 | "weights": [ 27 | { 28 | "name": "model.layers.${layer_index}.input_layernorm.weight" 29 | }, 30 | { 31 | "name": "model.layers.${layer_index}.input_layernorm.bias" 32 | }, 33 | { 34 | "name": "model.layers.${layer_index}.mlp.down_proj.weight" 35 | }, 36 | { 37 | "name": "model.layers.${layer_index}.mlp.gate_proj.weight" 38 | }, 39 | { 40 | "name": "model.layers.${layer_index}.mlp.up_proj.weight" 41 | }, 42 | { 43 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 44 | }, 45 | { 46 | "name": "model.layers.${layer_index}.post_attention_layernorm.bias" 47 | }, 48 | { 49 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 50 | }, 51 | { 52 | "name": "model.layers.${layer_index}.self_attn.q_proj.bias", 53 | "optional": true 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 57 | }, 58 | { 59 | "name": "model.layers.${layer_index}.self_attn.k_proj.bias", 60 | "optional": true 61 | }, 62 | { 63 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 64 | }, 65 | { 66 | "name": "model.layers.${layer_index}.self_attn.v_proj.bias", 67 | "optional": true 68 | }, 69 | { 70 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 71 | } 72 | ] 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /mergekit/_data/architectures/starcoder2.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "starcoder2", 3 | "architectures": [ 4 | "Starcoder2ForCausalLM" 5 | ], 6 | "pre_weights": [ 7 | { 8 | "name": "model.embed_tokens.weight", 9 | "is_embed": true 10 | } 11 | ], 12 | "post_weights": [ 13 | { 14 | "name": "lm_head.weight", 15 | "is_embed": true, 16 | "optional": true, 17 | "tied_names": [ 18 | "model.embed_tokens.weight" 19 | ] 20 | }, 21 | { 22 | "name": "model.norm.bias" 23 | }, 24 | { 25 | "name": "model.norm.weight" 26 | } 27 | ], 28 | "num_layers_config_key": "num_hidden_layers", 29 | "layer_templates": { 30 | "weights": [ 31 | { 32 | "name": "model.layers.${layer_index}.input_layernorm.bias" 33 | }, 34 | { 35 | "name": "model.layers.${layer_index}.input_layernorm.weight" 36 | }, 37 | { 38 | "name": "model.layers.${layer_index}.self_attn.q_proj.bias" 39 | }, 40 | { 41 | "name": "model.layers.${layer_index}.self_attn.q_proj.weight" 42 | }, 43 | { 44 | "name": "model.layers.${layer_index}.self_attn.k_proj.bias" 45 | }, 46 | { 47 | "name": "model.layers.${layer_index}.self_attn.k_proj.weight" 48 | }, 49 | { 50 | "name": "model.layers.${layer_index}.self_attn.v_proj.bias" 51 | }, 52 | { 53 | "name": "model.layers.${layer_index}.self_attn.v_proj.weight" 54 | }, 55 | { 56 | "name": "model.layers.${layer_index}.self_attn.o_proj.bias" 57 | }, 58 | { 59 | "name": "model.layers.${layer_index}.self_attn.o_proj.weight" 60 | }, 61 | { 62 | "name": "model.layers.${layer_index}.post_attention_layernorm.bias" 63 | }, 64 | { 65 | "name": "model.layers.${layer_index}.post_attention_layernorm.weight" 66 | }, 67 | { 68 | "name": "model.layers.${layer_index}.mlp.c_fc.bias" 69 | }, 70 | { 71 | "name": "model.layers.${layer_index}.mlp.c_fc.weight" 72 | }, 73 | { 74 | "name": "model.layers.${layer_index}.mlp.c_proj.bias" 75 | }, 76 | { 77 | 
"name": "model.layers.${layer_index}.mlp.c_proj.weight" 78 | } 79 | ] 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/_data/chat_templates/__init__.py -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/alpaca.jinja: -------------------------------------------------------------------------------- 1 | {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} 2 | 3 | {% for message in messages %} 4 | {% if message['role'] == 'user' %} 5 | ### Instruction: 6 | {{ message['content']|trim -}} 7 | {% if not loop.last %} 8 | 9 | 10 | {% endif %} 11 | {% elif message['role'] == 'assistant' %} 12 | ### Response: 13 | {{ message['content']|trim -}} 14 | {% if not loop.last %} 15 | 16 | 17 | {% endif %} 18 | {% elif message['role'] == 'user_context' %} 19 | ### Input: 20 | {{ message['content']|trim -}} 21 | {% if not loop.last %} 22 | 23 | 24 | {% endif %} 25 | {% endif %} 26 | {% endfor %} 27 | {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} 28 | ### Response: 29 | {% endif %} 30 | -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/chatml.jinja: -------------------------------------------------------------------------------- 1 | {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %} 2 | {% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %} 3 | -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/exaone.jinja: -------------------------------------------------------------------------------- 1 | {% for message in messages %} 2 | {% if loop.first and message['role'] != 'system' %} 3 | {{ '[|system|][|endofturn|]\n' }} 4 | {% endif %} 5 | {{ '[|' + message['role'] + '|]' + message['content'] }} 6 | {% if message['role'] == 'user' %} 7 | {{ '\n' }} 8 | {% else %} 9 | {{ '[|endofturn|]\n' }} 10 | {% endif %} 11 | {% endfor %} 12 | {% if add_generation_prompt %} 13 | {{ '[|assistant|]' }} 14 | {% endif %} 15 | -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/llama3.jinja: -------------------------------------------------------------------------------- 1 | {% set loop_messages = messages %} 2 | {% for message in loop_messages %} 3 | {% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %} 4 | {% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %} 5 | {{ content }} 6 | {% endfor %} 7 | {% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} 8 | -------------------------------------------------------------------------------- /mergekit/_data/chat_templates/mistral.jinja: -------------------------------------------------------------------------------- 1 | {%- if messages[0]['role'] == 'system' %} 2 | {%- set system_message = messages[0]['content'] %} 3 | {%- set loop_messages = messages[1:] %} 4 | {%- else %} 5 | {%- set loop_messages = messages %} 6 | {%- endif %} 7 
| 8 | {{- bos_token }} 9 | {%- for message in loop_messages %} 10 | {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} 11 | {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }} 12 | {%- endif %} 13 | {%- if message['role'] == 'user' %} 14 | {%- if loop.first and system_message is defined %} 15 | {{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }} 16 | {%- else %} 17 | {{- ' [INST] ' + message['content'] + ' [/INST]' }} 18 | {%- endif %} 19 | {%- elif message['role'] == 'assistant' %} 20 | {{- ' ' + message['content'] + eos_token}} 21 | {%- else %} 22 | {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }} 23 | {%- endif %} 24 | {%- endfor %} 25 | -------------------------------------------------------------------------------- /mergekit/architecture/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | from functools import lru_cache 6 | from typing import TYPE_CHECKING, Optional 7 | 8 | from transformers import PretrainedConfig 9 | 10 | from mergekit.architecture.auto import infer_architecture_info 11 | from mergekit.architecture.base import ( 12 | ConfiguredModelArchitecture, 13 | ConfiguredModuleArchitecture, 14 | ModelArchitecture, 15 | ModuleArchitecture, 16 | ModuleDefinition, 17 | WeightInfo, 18 | ) 19 | from mergekit.architecture.json_definitions import NAME_TO_ARCH 20 | from mergekit.architecture.moe_defs import ( 21 | MixtralModuleArchitecture, 22 | Qwen3MoeModuleArchitecture, 23 | ) 24 | from mergekit.options import MergeOptions 25 | 26 | if TYPE_CHECKING: 27 | from mergekit.config import MergeConfiguration 28 | 29 | LOG = logging.getLogger(__name__) 30 | 31 | WARNED_ARCHITECTURE_NAMES = set() 32 | 33 | 34 | def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]: 35 | if len(config.architectures) != 1: 36 | raise RuntimeError("More than one architecture in config?") 37 | arch_name = config.architectures[0] 38 | 39 | if arch_name == MixtralModuleArchitecture.ARCHITECTURE_NAME: 40 | module = MixtralModuleArchitecture.from_config(config) 41 | return ModelArchitecture( 42 | modules={"default": ModuleDefinition(architecture=module)}, 43 | architectures=[arch_name], 44 | model_type="mixtral", 45 | ) 46 | elif arch_name == Qwen3MoeModuleArchitecture.ARCHITECTURE_NAME: 47 | module = Qwen3MoeModuleArchitecture.from_config(config) 48 | return ModelArchitecture( 49 | modules={"default": ModuleDefinition(architecture=module)}, 50 | architectures=[arch_name], 51 | model_type="qwen3_moe", 52 | ) 53 | elif arch_name in NAME_TO_ARCH: 54 | candidates = list(NAME_TO_ARCH[arch_name]) 55 | if len(candidates) == 1: 56 | return candidates[0] 57 | 58 | for c in candidates: 59 | if c.expected_model_type == config.model_type: 60 | return c 61 | LOG.warning( 62 | f"Multiple architectures for {arch_name}, none match model type {config.model_type}" 63 | ) 64 | 65 | if arch_name not in WARNED_ARCHITECTURE_NAMES: 66 | LOG.warning(f"No JSON architecture found for {arch_name}") 67 | WARNED_ARCHITECTURE_NAMES.add(arch_name) 68 | return None 69 | 70 | 71 | def get_architecture_info( 72 | config: "MergeConfiguration", options: MergeOptions 73 | ) -> ModelArchitecture: 74 | models = config.referenced_models() 75 | if not models: 76 | raise ValueError("No models 
referenced in config") 77 | 78 | model_arch_info = [ 79 | arch_info_for_config(m.config(trust_remote_code=options.trust_remote_code)) 80 | for m in models 81 | ] 82 | if all(arch is not None for arch in model_arch_info): 83 | if not options.allow_crimes and any( 84 | arch != model_arch_info[0] for arch in model_arch_info 85 | ): 86 | raise RuntimeError( 87 | "Must specify --allow-crimes to attempt to mix different architectures" 88 | ) 89 | return model_arch_info[0] 90 | 91 | # try to infer from all models 92 | return infer_architecture_info(tuple(models), config.base_model, options) 93 | 94 | 95 | __all__ = [ 96 | "ModelArchitecture", 97 | "ModuleArchitecture", 98 | "ModuleDefinition", 99 | "ConfiguredModuleArchitecture", 100 | "ConfiguredModelArchitecture", 101 | "WeightInfo", 102 | "get_architecture_info", 103 | "arch_info_for_config", 104 | ] 105 | -------------------------------------------------------------------------------- /mergekit/architecture/moe_defs.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import ClassVar, List, Optional 5 | 6 | from pydantic import BaseModel 7 | from transformers import PretrainedConfig 8 | 9 | from mergekit.architecture.base import ( 10 | ModuleArchitecture, 11 | WeightInfo, 12 | ) 13 | from mergekit.architecture.json_definitions import NAME_TO_ARCH 14 | 15 | MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0] 16 | MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture 17 | 18 | 19 | class MixtralModuleArchitecture(ModuleArchitecture, BaseModel): 20 | ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM" 21 | num_local_experts: int 22 | 23 | def name(self) -> str: 24 | return "mixtral" 25 | 26 | @classmethod 27 | def from_config(cls, config: PretrainedConfig): 28 | return MixtralModuleArchitecture(num_local_experts=config.num_local_experts) 29 | 30 | def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: 31 | return MISTRAL_MODULE_ARCH.pre_weights(config) 32 | 33 | def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: 34 | return MISTRAL_MODULE_ARCH.post_weights(config) 35 | 36 | def num_layers_config_key(self) -> str: 37 | return MISTRAL_MODULE_ARCH.num_layers_config_key() 38 | 39 | def layer_weights( 40 | self, index: int, config: PretrainedConfig 41 | ) -> Optional[List[WeightInfo]]: 42 | num_experts = self.num_local_experts 43 | prefix = f"model.layers.{index}" 44 | tensor_names = [] 45 | for expert_idx in range(num_experts): 46 | for param in ("w1", "w2", "w3"): 47 | tensor_names.append( 48 | prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight" 49 | ) 50 | tensor_names.append(prefix + ".block_sparse_moe.gate.weight") 51 | res = [] 52 | for name in tensor_names: 53 | res.append(WeightInfo(name=name)) 54 | for weight_info in MISTRAL_MODULE_ARCH.layer_weights(index, config): 55 | if ".mlp." 
in weight_info.name: 56 | continue 57 | res.append(weight_info) 58 | return res 59 | 60 | 61 | QWEN3_INFO = NAME_TO_ARCH["Qwen3ForCausalLM"][0] 62 | QWEN3_MODULE_ARCH = QWEN3_INFO.modules["default"].architecture 63 | 64 | 65 | class Qwen3MoeModuleArchitecture(ModuleArchitecture, BaseModel): 66 | ARCHITECTURE_NAME: ClassVar[str] = "Qwen3MoeForCausalLM" 67 | num_experts: int 68 | 69 | def name(self) -> str: 70 | return "qwen3_moe" 71 | 72 | @classmethod 73 | def from_config(cls, config: PretrainedConfig): 74 | return Qwen3MoeModuleArchitecture(num_experts=config.num_experts) 75 | 76 | def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: 77 | return QWEN3_MODULE_ARCH.pre_weights(config) 78 | 79 | def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: 80 | return QWEN3_MODULE_ARCH.post_weights(config) 81 | 82 | def num_layers_config_key(self) -> str: 83 | return QWEN3_MODULE_ARCH.num_layers_config_key() 84 | 85 | def layer_weights( 86 | self, index: int, config: PretrainedConfig 87 | ) -> Optional[List[WeightInfo]]: 88 | prefix = f"model.layers.{index}" 89 | tensor_names = [] 90 | for expert_idx in range(self.num_experts): 91 | for param in ("up_proj", "gate_proj", "down_proj"): 92 | tensor_names.append( 93 | prefix + f".mlp.experts.{expert_idx}.{param}.weight" 94 | ) 95 | tensor_names.append(prefix + ".mlp.gate.weight") 96 | res = [] 97 | for name in tensor_names: 98 | res.append(WeightInfo(name=name)) 99 | for weight_info in QWEN3_MODULE_ARCH.layer_weights(index, config): 100 | if ".mlp." in weight_info.name: 101 | continue 102 | res.append(weight_info) 103 | return res 104 | -------------------------------------------------------------------------------- /mergekit/evo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/evo/__init__.py -------------------------------------------------------------------------------- /mergekit/evo/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | from typing import List, Optional 6 | 7 | from pydantic import BaseModel, model_validator 8 | 9 | from mergekit.evo.genome import ModelGenomeDefinition 10 | 11 | 12 | class TaskConfiguration(BaseModel, frozen=True): 13 | name: str 14 | weight: float = 1.0 15 | metric: str = "acc,none" 16 | 17 | @model_validator(mode="before") 18 | def validate_string(cls, value): 19 | if isinstance(value, str): 20 | return {"name": value} 21 | return value 22 | 23 | 24 | class EvolMergeConfiguration(BaseModel, frozen=True): 25 | genome: ModelGenomeDefinition 26 | tasks: List[TaskConfiguration] 27 | limit: Optional[int] = None 28 | num_fewshot: Optional[int] = None 29 | shuffle: bool = False 30 | random_init: bool = False 31 | apply_chat_template: bool = True 32 | fewshot_as_multiturn: bool = True 33 | 34 | 35 | NAUGHTY_PREFIXES = [ 36 | "mmlu", 37 | "hendrycks", 38 | "agieval", 39 | "gsm8k", 40 | "hellaswag", 41 | "winogrande", 42 | "arc_", 43 | "ai2_arc", 44 | "truthfulqa", 45 | "bigbench", 46 | "piqa", 47 | "openbookqa", 48 | "leaderboard", 49 | ] 50 | 51 | 52 | def check_for_naughty_config(config: EvolMergeConfiguration, allow: bool = False): 53 | """ 54 | Check if the given configuration is naughty and should be disallowed. 
55 | 56 | mergekit-evolve is perfectly set up to directly optimize against the test set 57 | of common benchmarks, which just makes the world a worse place. There are 58 | cases where this is useful but it deserves a giant honking warning. 59 | """ 60 | suffix = "" 61 | if not allow: 62 | suffix = ( 63 | " To proceed, set the " 64 | "--i-understand-the-depths-of-the-evils-i-am-unleashing flag." 65 | ) 66 | for task in config.tasks: 67 | for prefix in NAUGHTY_PREFIXES: 68 | if task.name.startswith(prefix): 69 | if task.name.endswith("_train"): 70 | # there aren't any tasks that match this pattern in base 71 | # lm-eval, but it'd be a sane thing to do to add tasks for 72 | # the training sets of these benchmarks. don't warn about 73 | # them 74 | continue 75 | 76 | message = ( 77 | f"Task {task.name} is a common benchmark task. " 78 | "Optimizing against this task directly is unsporting at best " 79 | "and outright malicious at worst. Using mergekit-evolve to " 80 | "game benchmarks will be a black mark on your name for a " 81 | f"thousand generations.{suffix}" 82 | ) 83 | if not allow: 84 | raise ValueError(message) 85 | else: 86 | logging.warning(message) 87 | -------------------------------------------------------------------------------- /mergekit/evo/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | import os 6 | import shutil 7 | import tempfile 8 | from typing import Any, Dict, List, Optional, Union 9 | 10 | import lm_eval 11 | import lm_eval.api.model 12 | import lm_eval.models.huggingface 13 | import lm_eval.tasks 14 | import ray 15 | import ray.util.queue 16 | import ray.util.scheduling_strategies 17 | import torch 18 | 19 | from mergekit.evo.config import TaskConfiguration 20 | from mergekit.evo.genome import InvalidGenotypeError, ModelGenome 21 | from mergekit.evo.monkeypatch import monkeypatch_lmeval_vllm 22 | from mergekit.merge import run_merge 23 | from mergekit.options import MergeOptions 24 | 25 | 26 | def _eval_model( 27 | model: Union[str, lm_eval.api.model.LM], 28 | tasks: List[TaskConfiguration], 29 | model_args: Optional[Dict[str, Any]] = None, 30 | task_manager: Optional[lm_eval.tasks.TaskManager] = None, 31 | **kwargs, 32 | ) -> Dict[str, Any]: 33 | results = lm_eval.evaluator.simple_evaluate( 34 | model=model, 35 | model_args=model_args, 36 | tasks=list(set([task.name for task in tasks])), 37 | log_samples=False, 38 | verbosity="WARNING", 39 | task_manager=task_manager, 40 | **kwargs, 41 | ) 42 | 43 | logging.info(results["results"]) 44 | res = 0 45 | for task in tasks: 46 | res += results["results"][task.name][task.metric] * task.weight 47 | return {"score": res, "results": results["results"]} 48 | 49 | 50 | def evaluate_model( 51 | merged_path: str, 52 | tasks: List[TaskConfiguration], 53 | num_fewshot: Optional[int], 54 | limit: Optional[int], 55 | vllm: bool, 56 | batch_size: Optional[int] = None, 57 | task_manager: Optional[lm_eval.tasks.TaskManager] = None, 58 | model_kwargs: Optional[Dict[str, Any]] = None, 59 | **kwargs, 60 | ) -> dict: 61 | # monkeypatch_tqdm() 62 | monkeypatch_lmeval_vllm() 63 | try: 64 | model_args = { 65 | "pretrained": merged_path, 66 | "dtype": "bfloat16", 67 | **(model_kwargs or {}), 68 | } 69 | if vllm: 70 | model_args["gpu_memory_utilization"] = 0.8 71 | model_args["tensor_parallel_size"] = 1 72 | model_args["batch_size"] = "auto" 73 | model_args["max_model_len"] = 4096 74 | else: 75 | 
model_args["use_cache"] = True 76 | 77 | res = _eval_model( 78 | "vllm" if vllm else "huggingface", 79 | tasks, 80 | model_args, 81 | num_fewshot=num_fewshot, 82 | limit=limit, 83 | batch_size=batch_size, 84 | task_manager=task_manager, 85 | **kwargs, 86 | ) 87 | return res 88 | finally: 89 | shutil.rmtree(merged_path) 90 | 91 | 92 | evaluate_model_ray = ray.remote(num_cpus=1, num_gpus=1.0)(evaluate_model) 93 | 94 | 95 | def merge_model( 96 | genotype: torch.Tensor, 97 | genome: ModelGenome, 98 | model_storage_path: str, 99 | merge_options: MergeOptions, 100 | ) -> str: 101 | # monkeypatch_tqdm() 102 | try: 103 | cfg = genome.genotype_merge_config(genotype) 104 | except InvalidGenotypeError as e: 105 | logging.error("Invalid genotype", exc_info=e) 106 | return None 107 | os.makedirs(model_storage_path, exist_ok=True) 108 | res = tempfile.mkdtemp(prefix="merged", dir=model_storage_path) 109 | run_merge(cfg, out_path=res, options=merge_options) 110 | return res 111 | 112 | 113 | merge_model_ray = ray.remote( 114 | num_cpus=1, 115 | num_gpus=1, 116 | max_retries=3, 117 | retry_exceptions=[ConnectionError], 118 | )(merge_model) 119 | -------------------------------------------------------------------------------- /mergekit/evo/monkeypatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | 5 | import torch 6 | import transformers 7 | 8 | 9 | def monkeypatch_lmeval_shuffle(): 10 | """Monkeypatch lm_eval to shuffle the dataset after downloading.""" 11 | import lm_eval.api.task 12 | 13 | if hasattr(lm_eval.api.task.Task, "_monkey_patched"): 14 | return 15 | 16 | _old_task_dl = lm_eval.api.task.Task.download 17 | 18 | def _dl_shuffled(self: lm_eval.api.task.Task, *args, **kwargs): 19 | _old_task_dl(self, *args, **kwargs) 20 | self.dataset = self.dataset.shuffle() 21 | 22 | lm_eval.api.task.Task.download = _dl_shuffled 23 | 24 | _old_ct_dl = lm_eval.api.task.ConfigurableTask.download 25 | 26 | def _ct_dl_shuffled(self, *args, **kwargs): 27 | _old_ct_dl(self, *args, **kwargs) 28 | self.dataset = self.dataset.shuffle() 29 | 30 | lm_eval.api.task.ConfigurableTask.download = _ct_dl_shuffled 31 | 32 | lm_eval.api.task.Task._monkey_patched = True 33 | print("monkey has been patched") 34 | 35 | 36 | def monkeypatch_tqdm(lm_eval: bool = True, mergekit: bool = True): 37 | """Patch lm_eval & mergekit to use Ray's tqdm for progress bars.""" 38 | 39 | from ray.experimental.tqdm_ray import tqdm as tqdm_ray 40 | 41 | def _tqdm_wrap(iterable=None, disable: bool = False, **kwargs): 42 | if disable: 43 | if iterable is not None: 44 | return iterable 45 | return lambda x: x 46 | res = tqdm_ray(iterable=iterable, **kwargs, flush_interval_s=1.0) 47 | res.refresh() 48 | return res 49 | 50 | def _patch_lm_eval(): 51 | import lm_eval 52 | 53 | if hasattr(lm_eval, "_mk_tqdm_patched"): 54 | return 55 | 56 | import lm_eval.api.metrics 57 | import lm_eval.api.model 58 | import lm_eval.api.task 59 | import lm_eval.models.huggingface 60 | import lm_eval.models.vllm_causallms 61 | 62 | for module in ( 63 | lm_eval.models.huggingface, 64 | lm_eval.models.vllm_causallms, 65 | lm_eval.api.model, 66 | lm_eval.api.task, 67 | lm_eval.api.metrics, 68 | ): 69 | setattr(module, "tqdm", _tqdm_wrap) 70 | 71 | lm_eval._mk_tqdm_patched = True 72 | 73 | if lm_eval: 74 | _patch_lm_eval() 75 | 76 | if mergekit: 77 | del mergekit 78 | 79 | import mergekit 80 | import mergekit.graph 81 | import mergekit.merge 82 | import 
mergekit.tokenizer 83 | 84 | fake_module = type("fake_module", (), {"tqdm": staticmethod(_tqdm_wrap)})() 85 | 86 | mergekit.graph.tqdm = fake_module 87 | mergekit.merge.tqdm = fake_module 88 | mergekit.tokenizer.tqdm = fake_module 89 | 90 | 91 | def monkeypatch_lmeval_vllm(): 92 | # HACK: fix crash on some tasks due to unset AUTO_MODEL_CLASS for vLLM 93 | import lm_eval.models.vllm_causallms 94 | 95 | lm_eval.models.vllm_causallms.VLLM.AUTO_MODEL_CLASS = ( 96 | transformers.AutoModelForCausalLM 97 | ) 98 | 99 | 100 | class NoInit: 101 | def __enter__(self): 102 | def noop(*args, **kwargs): 103 | pass 104 | 105 | (k, u, n) = ( 106 | torch.nn.init.kaiming_uniform_, 107 | torch.nn.init.uniform_, 108 | torch.nn.init.normal_, 109 | ) 110 | torch.nn.init.kaiming_uniform_ = noop 111 | torch.nn.init.uniform_ = noop 112 | torch.nn.init.normal_ = noop 113 | 114 | transformers.modeling_utils._init_weights = False 115 | self.funcs = (k, u, n) 116 | 117 | def __exit__(self, *args): 118 | (k, u, n) = self.funcs 119 | ( 120 | torch.nn.init.kaiming_uniform_, 121 | torch.nn.init.uniform_, 122 | torch.nn.init.normal_, 123 | ) = ( 124 | k, 125 | u, 126 | n, 127 | ) 128 | transformers.modeling_utils._init_weights = True 129 | -------------------------------------------------------------------------------- /mergekit/io/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from mergekit.io.lazy_tensor_loader import ( 5 | LazyTensorLoader, 6 | ShardedTensorIndex, 7 | ShardInfo, 8 | ) 9 | from mergekit.io.tensor_writer import TensorWriter 10 | 11 | __all__ = [ 12 | "LazyTensorLoader", 13 | "ShardedTensorIndex", 14 | "ShardInfo", 15 | "TensorWriter", 16 | ] 17 | -------------------------------------------------------------------------------- /mergekit/io/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Dict, Optional, Sequence 6 | 7 | import safetensors 8 | import torch 9 | 10 | from mergekit.io.lazy_unpickle import ( 11 | DeferredLoad, 12 | LazyUnpickleModule, 13 | TorchArchiveReader, 14 | torch_lazy_load, 15 | ) 16 | 17 | 18 | class TensorLoader(ABC): 19 | """Base class for (potentially lazy) tensor loaders.""" 20 | 21 | @abstractmethod 22 | def get_tensor(self, key: str) -> torch.Tensor: ... 23 | 24 | @abstractmethod 25 | def keys(self) -> Sequence[str]: ... 
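# Illustrative usage (an added sketch, not part of the original file; the
# shard filename and tensor name below are hypothetical examples):
#
#     loader = TensorLoader.get("model-00001-of-00002.safetensors")
#     weight = loader.get_tensor("model.embed_tokens.weight")
#
# .safetensors shards take the safetensors.safe_open fast path in the
# factory below; other shard formats fall through to the pickle-based
# loaders.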
26 | 27 | @classmethod 28 | def get( 29 | cls, 30 | shard_path: str, 31 | use_lazy_unpickle: bool = False, 32 | device: Optional[str] = None, 33 | ) -> "TensorLoader": 34 | if shard_path.lower().endswith(".safetensors"): 35 | # not a subclass of TensorLoader, but exposes same api 36 | return safetensors.safe_open( 37 | shard_path, framework="pt", device=device or "cpu" 38 | ) 39 | elif use_lazy_unpickle: 40 | return LazyPickleLoader(shard_path, device=device) 41 | return DumbPytorchLoader(shard_path, device=device) 42 | 43 | 44 | class LazyPickleLoader(TensorLoader): 45 | """Loader for pytorch files using a custom unpickler and vigorous monkeypatching.""" 46 | 47 | zip_reader: TorchArchiveReader 48 | index: Dict[str, DeferredLoad] 49 | device: Optional[str] = None 50 | 51 | def __init__(self, path: str, device: Optional[str] = None): 52 | self.zip_reader = TorchArchiveReader(path) 53 | self.device = device 54 | with torch_lazy_load(): 55 | self.index = torch.load(path, pickle_module=LazyUnpickleModule) 56 | 57 | def get_tensor(self, key: str) -> torch.Tensor: 58 | if key not in self.index: 59 | raise KeyError(key) 60 | 61 | return self.index[key].execute(self.zip_reader, map_location=self.device) 62 | 63 | def keys(self) -> Sequence[str]: 64 | return self.index.keys() 65 | 66 | 67 | class DumbPytorchLoader(TensorLoader): 68 | """Naive `torch.load` shard loading.""" 69 | 70 | tensors: Dict[str, torch.Tensor] 71 | 72 | def __init__(self, path: str, device: Optional[str] = None): 73 | self.tensors = torch.load(path, map_location=device, weights_only=True) 74 | 75 | def get_tensor(self, key: str) -> torch.Tensor: 76 | return self.tensors[key] 77 | 78 | def keys(self) -> Sequence[str]: 79 | return self.tensors.keys() 80 | -------------------------------------------------------------------------------- /mergekit/merge_methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import mergekit.merge_methods.multislerp 5 | import mergekit.merge_methods.nearswap 6 | import mergekit.merge_methods.sce 7 | from mergekit.merge_methods.base import MergeMethod 8 | from mergekit.merge_methods.generalized_task_arithmetic import ( 9 | GeneralizedTaskArithmeticMerge, 10 | ) 11 | from mergekit.merge_methods.registry import REGISTERED_MERGE_METHODS 12 | 13 | 14 | def get(method: str) -> MergeMethod: 15 | if method in REGISTERED_MERGE_METHODS: 16 | return REGISTERED_MERGE_METHODS[method] 17 | raise RuntimeError(f"Unimplemented merge method {method}") 18 | 19 | 20 | __all__ = [ 21 | "MergeMethod", 22 | "get", 23 | "GeneralizedTaskArithmeticMerge", 24 | "REGISTERED_MERGE_METHODS", 25 | ] 26 | -------------------------------------------------------------------------------- /mergekit/merge_methods/arcee_fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import Dict, List, Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from typing_extensions import override 9 | 10 | from mergekit.architecture import WeightInfo 11 | from mergekit.common import ModelReference 12 | from mergekit.graph import Task 13 | from mergekit.merge_methods.base import ( 14 | ConfigParameterDef, 15 | MergeMethod, 16 | MergeTensorInput, 17 | ) 18 | from mergekit.merge_methods.rectify_embed import rectify_embed_sizes 19 | 20 | 21 | class DynamicThresholdFusion: 22 | def 
approximate_quantiles(self, tensor, q): 23 | # Flatten the tensor 24 | flat_tensor = tensor.view(-1) 25 | 26 | # If tensor is too large, sample it 27 | if flat_tensor.numel() > 1e6: 28 | flat_tensor = flat_tensor[torch.randperm(flat_tensor.numel())[:1000000]] 29 | 30 | # Sort the (possibly sampled) tensor 31 | sorted_tensor, _ = torch.sort(flat_tensor) 32 | 33 | # Compute quantile indices 34 | quantile_indices = (q * (sorted_tensor.numel() - 1)).long() 35 | 36 | # Return quantiles 37 | return sorted_tensor[quantile_indices] 38 | 39 | def calculate_dynamic_threshold(self, importance_scores): 40 | # Approximate median and quantiles 41 | median = self.approximate_quantiles(importance_scores, torch.tensor([0.5]))[0] 42 | q1, q3 = self.approximate_quantiles( 43 | importance_scores, torch.tensor([0.25, 0.75]) 44 | ) 45 | 46 | # Calculate IQR 47 | iqr = q3 - q1 48 | 49 | # Set threshold as median + 1.5 * IQR 50 | dynamic_threshold = median + 1.5 * iqr 51 | 52 | return dynamic_threshold 53 | 54 | def compute_fusion_mask(self, importance_scores): 55 | threshold = self.calculate_dynamic_threshold(importance_scores) 56 | fusion_mask = (importance_scores >= threshold).float() 57 | return fusion_mask, threshold 58 | 59 | 60 | class ArceeFusionMergeTask(Task[torch.Tensor]): 61 | gather_tensors: MergeTensorInput 62 | base_model: ModelReference 63 | weight_info: WeightInfo 64 | 65 | def uses_accelerator(self) -> bool: 66 | return True 67 | 68 | def arguments(self) -> Dict[str, Task]: 69 | return {"tensors": self.gather_tensors} 70 | 71 | def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: 72 | if len(tensors) == 1: 73 | return list(tensors.values())[0] 74 | elif len(tensors) != 2: 75 | raise RuntimeError("ArceeFusion merge expects exactly two models") 76 | elif self.base_model not in tensors: 77 | raise RuntimeError("Base model not in input tensors") 78 | 79 | [a, b] = list(tensors.items()) 80 | if a[0] != self.base_model: 81 | [a, b] = [b, a] 82 | prepped_tensors = [a[1], b[1]] 83 | 84 | rectify_embed_sizes(self.weight_info, prepped_tensors) 85 | 86 | importance_scores = self._compute_importance( 87 | prepped_tensors[1], prepped_tensors[0] 88 | ) 89 | dynamic_threshold_fusion = DynamicThresholdFusion() 90 | fusion_mask, _threshold = dynamic_threshold_fusion.compute_fusion_mask( 91 | importance_scores 92 | ) 93 | 94 | delta = prepped_tensors[1] - prepped_tensors[0] 95 | masked_delta = delta * fusion_mask 96 | fused = prepped_tensors[0] + masked_delta 97 | 98 | return fused 99 | 100 | def _compute_importance( 101 | self, params: torch.Tensor, base_params: torch.Tensor, eps: float = 1e-8 102 | ) -> torch.Tensor: 103 | diff = (params - base_params).abs() 104 | p = F.softmax(params, dim=-1) + eps 105 | q = F.softmax(base_params, dim=-1) + eps 106 | kl_div = torch.sum(p * torch.log(p / q), dim=-1) 107 | return diff * kl_div.unsqueeze(-1) 108 | 109 | 110 | class ArceeFusionMerge(MergeMethod): 111 | def name(self) -> str: 112 | return "arcee_fusion" 113 | 114 | @override 115 | def pretty_name(self) -> Optional[str]: 116 | return "Arcee Fusion" 117 | 118 | @override 119 | def reference_url(self) -> Optional[str]: 120 | return "https://arcee.ai" 121 | 122 | def parameters(self) -> List[ConfigParameterDef]: 123 | return [] 124 | 125 | def make_task( 126 | self, 127 | output_weight: WeightInfo, 128 | tensors: MergeTensorInput, 129 | base_model: Optional[ModelReference], 130 | **kwargs, 131 | ) -> Task[torch.Tensor]: 132 | return ArceeFusionMergeTask( 133 | gather_tensors=tensors, 
weight_info=output_weight, base_model=base_model 134 | ) 135 | -------------------------------------------------------------------------------- /mergekit/merge_methods/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import Any, Dict, List, Optional, Union 6 | 7 | import torch 8 | from pydantic import BaseModel 9 | from typing_extensions import TypeAlias 10 | 11 | from mergekit.architecture import WeightInfo 12 | from mergekit.common import ImmutableMap, ModelReference 13 | from mergekit.graph import Task 14 | from mergekit.io.tasks import GatherTensors 15 | from mergekit.tokenizer import PermutedEmbeddings 16 | 17 | 18 | class TensorDictWrapper(Task[Dict[ModelReference, torch.Tensor]]): 19 | tensors: ImmutableMap[ModelReference, Task[torch.Tensor]] 20 | 21 | def arguments(self) -> Dict[str, Task]: 22 | return { 23 | k.model_dump_json( 24 | exclude_none=True, exclude_defaults=True, round_trip=True 25 | ): v 26 | for k, v in self.tensors.items() 27 | } 28 | 29 | def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]: 30 | return {ModelReference.model_validate_json(k): v for k, v in kwargs.items()} 31 | 32 | 33 | MergeTensorInput: TypeAlias = Union[ 34 | GatherTensors, PermutedEmbeddings, TensorDictWrapper 35 | ] 36 | 37 | 38 | class ConfigParameterDef(BaseModel): 39 | name: str 40 | required: bool = False 41 | default_value: Any = None 42 | 43 | 44 | class MergeMethod(ABC): 45 | def tensor_parameters(self) -> List[ConfigParameterDef]: 46 | return [] 47 | 48 | def parameters(self) -> List[ConfigParameterDef]: 49 | return [] 50 | 51 | @abstractmethod 52 | def name(self) -> str: ... 53 | 54 | def pretty_name(self) -> Optional[str]: 55 | return None 56 | 57 | def reference_url(self) -> Optional[str]: 58 | return None 59 | 60 | @abstractmethod 61 | def make_task( 62 | self, 63 | *, 64 | output_weight: WeightInfo, 65 | tensors: MergeTensorInput, 66 | parameters: ImmutableMap[str, Any], 67 | tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], 68 | base_model: Optional[ModelReference], 69 | ) -> Task: ... 
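# A minimal concrete implementation of this ABC, added here as a hedged
# illustration only (MeanMergeTask and MeanMerge are hypothetical names,
# not part of mergekit; LinearMerge in the next file is the real
# reference implementation of the same pattern):
#
#     class MeanMergeTask(Task[torch.Tensor]):
#         gather_tensors: MergeTensorInput
#
#         def arguments(self) -> Dict[str, Task]:
#             return {"tensors": self.gather_tensors}
#
#         def execute(
#             self, tensors: Dict[ModelReference, torch.Tensor]
#         ) -> torch.Tensor:
#             # unweighted elementwise mean across all input models
#             return torch.stack(list(tensors.values()), dim=0).mean(dim=0)
#
#     class MeanMerge(MergeMethod):
#         def name(self) -> str:
#             return "mean"
#
#         def make_task(self, *, output_weight, tensors, parameters,
#                       tensor_parameters, base_model) -> Task:
#             return MeanMergeTask(gather_tensors=tensors)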
70 | -------------------------------------------------------------------------------- /mergekit/merge_methods/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import Any, Dict, List, Optional 5 | 6 | import torch 7 | from typing_extensions import override 8 | 9 | from mergekit.architecture import WeightInfo 10 | from mergekit.common import ImmutableMap, ModelReference 11 | from mergekit.graph import Task 12 | from mergekit.merge_methods.base import ( 13 | ConfigParameterDef, 14 | MergeMethod, 15 | MergeTensorInput, 16 | ) 17 | from mergekit.merge_methods.rectify_embed import rectify_embed_sizes 18 | 19 | 20 | class LinearMergeTask(Task[torch.Tensor]): 21 | gather_tensors: MergeTensorInput 22 | tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] 23 | normalize: bool 24 | weight_info: WeightInfo 25 | 26 | def uses_accelerator(self) -> bool: 27 | return True 28 | 29 | def arguments(self) -> Dict[str, Task]: 30 | return {"tensors": self.gather_tensors} 31 | 32 | def execute( 33 | self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs 34 | ) -> torch.Tensor: 35 | keys = list(tensors.keys()) 36 | 37 | tensors = [tensors[key] for key in keys] 38 | weights = [self.tensor_parameters[key]["weight"] for key in keys] 39 | 40 | rectify_embed_sizes(self.weight_info, tensors) 41 | 42 | unique_shapes = set(t.shape for t in tensors) 43 | if len(unique_shapes) != 1: 44 | raise RuntimeError( 45 | f"Tensor size mismatch for {self.weight_info.name}, sizes: {list(unique_shapes)}" 46 | ) 47 | 48 | tensors = torch.stack(tensors, dim=0) 49 | weights = torch.tensor(weights, dtype=tensors.dtype, device=tensors.device) 50 | while len(weights.shape) < len(tensors.shape): 51 | weights.unsqueeze_(-1) 52 | 53 | res = (weights * tensors).sum(dim=0) 54 | if self.normalize: 55 | res = res / weights.sum(dim=0) 56 | 57 | return res 58 | 59 | def group_label(self) -> Optional[str]: 60 | return self.gather_tensors.group_label() 61 | 62 | 63 | class LinearMerge(MergeMethod): 64 | def name(self) -> str: 65 | return "linear" 66 | 67 | @override 68 | def pretty_name(self) -> Optional[str]: 69 | return "Linear" 70 | 71 | @override 72 | def reference_url(self) -> Optional[str]: 73 | return "https://arxiv.org/abs/2203.05482" 74 | 75 | def parameters(self) -> List[ConfigParameterDef]: 76 | return [ 77 | ConfigParameterDef(name="normalize", required=False, default_value=True), 78 | ] 79 | 80 | def tensor_parameters(self) -> List[ConfigParameterDef]: 81 | return [ConfigParameterDef(name="weight", required=True)] 82 | 83 | def make_task( 84 | self, 85 | *, 86 | output_weight: WeightInfo, 87 | tensors: MergeTensorInput, 88 | parameters: Dict[str, Any], 89 | tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], 90 | **_kwargs, 91 | ) -> Task: 92 | return LinearMergeTask( 93 | gather_tensors=tensors, 94 | tensor_parameters=tensor_parameters, 95 | normalize=parameters["normalize"], 96 | weight_info=output_weight, 97 | ) 98 | -------------------------------------------------------------------------------- /mergekit/merge_methods/model_stock.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | from typing import Any, Dict, List, Optional 6 | 7 | import torch 8 | from typing_extensions import override 9 | 10 | from mergekit.architecture 
import WeightInfo 11 | from mergekit.common import ImmutableMap, ModelReference 12 | from mergekit.graph import Task 13 | from mergekit.merge_methods.base import ( 14 | ConfigParameterDef, 15 | MergeMethod, 16 | MergeTensorInput, 17 | ) 18 | from mergekit.merge_methods.rectify_embed import rectify_embed_sizes 19 | 20 | 21 | class ModelStockMergeTask(Task[torch.Tensor]): 22 | gather_tensors: MergeTensorInput 23 | base_model: ModelReference 24 | weight_info: WeightInfo 25 | filter_wise: bool = False 26 | 27 | def uses_accelerator(self) -> bool: 28 | return True 29 | 30 | def arguments(self) -> Dict[str, Task]: 31 | return {"tensors": self.gather_tensors} 32 | 33 | def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: 34 | if len(tensors) == 1 and self.base_model in tensors: 35 | return tensors[self.base_model] 36 | if len(tensors) < 3: 37 | if self.weight_info.optional: 38 | logging.warning( 39 | f"Optional weight {self.weight_info.name} not present in enough models, discarding" 40 | ) 41 | return None 42 | 43 | raise ValueError( 44 | "ModelStockMerge requires at least 3 models (base plus two+ others)" 45 | ) 46 | 47 | w_0, ws = self.get_rectified_weights(tensors) 48 | out_shape = w_0.shape 49 | 50 | if self.filter_wise: 51 | if w_0.dim() == 1: 52 | # bias (or other single-vector) parameters should be treated as row vectors 53 | w_0 = w_0.unsqueeze(0) 54 | ws = [w.unsqueeze(0) for w in ws] 55 | else: 56 | w_0 = w_0.view(-1) 57 | ws = [w.view(-1) for w in ws] 58 | 59 | offsets = [w - w_0 for w in ws] 60 | 61 | # now there is a question of how to come up with a value for theta. 62 | # in the two-vector case, we can get an exact angle between the two vectors 63 | # but the paper doesn't explicitly say what to do in the multi-vector case - 64 | # they keep using a singular theta value and don't elaborate on how to 65 | # calculate it. i'm going to assume an average of pairwise angles for now? i guess? 
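# (added clarification, hedged: concretely, the block below averages
# cos(theta) over all pairs of offsets w_i - w_0 and then applies the
# interpolation ratio t = N*cos(theta) / (1 + (N-1)*cos(theta)); for
# N = 2 this reduces to 2*cos(theta) / (1 + cos(theta)), recovering the
# exact two-vector case.)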
66 | 67 | cos_thetas = [] 68 | for i, w_0_offset in enumerate(offsets): 69 | for j in range(i + 1, len(offsets)): 70 | w_1_offset = offsets[j] 71 | 72 | norm_product = torch.norm(w_0_offset, dim=-1) * torch.norm( 73 | w_1_offset, dim=-1 74 | ) 75 | cos_theta = ( 76 | (w_0_offset * w_1_offset).sum(dim=-1) / norm_product.clamp(min=1e-6) 77 | ).clamp(-1, 1) 78 | cos_thetas.append(cos_theta) 79 | 80 | cos_theta = torch.stack(cos_thetas).mean(dim=0).unsqueeze(-1) 81 | N = len(ws) 82 | t = (N * cos_theta) / (1 + (N - 1) * cos_theta) 83 | 84 | w_avg = sum(ws) / len(ws) 85 | w_h = t * w_avg + (1 - t) * w_0 86 | 87 | return w_h.reshape(out_shape) 88 | 89 | def get_rectified_weights(self, tensors: Dict[ModelReference, torch.Tensor]): 90 | if self.base_model not in tensors: 91 | raise ValueError("Base model tensor not found") 92 | 93 | all_weights = [tensors[self.base_model]] + [ 94 | tensors[k] for k in tensors if k != self.base_model 95 | ] 96 | rectify_embed_sizes(self.weight_info, all_weights) 97 | w_0 = all_weights[0] 98 | ws = all_weights[1:] 99 | return w_0, ws 100 | 101 | def group_label(self) -> Optional[str]: 102 | return self.gather_tensors.group_label() 103 | 104 | 105 | class ModelStockMerge(MergeMethod): 106 | def name(self) -> str: 107 | return "model_stock" 108 | 109 | @override 110 | def pretty_name(self) -> Optional[str]: 111 | return "Model Stock" 112 | 113 | @override 114 | def reference_url(self): 115 | return "https://arxiv.org/abs/2403.19522" 116 | 117 | def parameters(self) -> List[ConfigParameterDef]: 118 | return [ 119 | ConfigParameterDef(name="filter_wise", required=False, default_value=False) 120 | ] 121 | 122 | def make_task( 123 | self, 124 | *, 125 | output_weight: WeightInfo, 126 | tensors: MergeTensorInput, 127 | base_model: Optional[ModelReference], 128 | parameters: ImmutableMap[str, Any], 129 | **_kwargs, 130 | ) -> Task: 131 | return ModelStockMergeTask( 132 | gather_tensors=tensors, 133 | base_model=base_model, 134 | weight_info=output_weight, 135 | filter_wise=parameters["filter_wise"], 136 | ) 137 | -------------------------------------------------------------------------------- /mergekit/merge_methods/multislerp.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import List, Optional 5 | 6 | import torch 7 | 8 | from mergekit.merge_methods.easy_define import merge_method 9 | 10 | 11 | @merge_method( 12 | name="multislerp", 13 | pretty_name="Multi-SLERP", 14 | reference_url="https://goddard.blog/posts/multislerp-wow-what-a-cool-idea", 15 | ) 16 | def multislerp( 17 | tensors: List[torch.Tensor], 18 | weight: List[float], 19 | base_tensor: Optional[torch.Tensor] = None, 20 | normalize_weights: bool = True, 21 | eps: float = 1e-8, 22 | ): 23 | """ 24 | Implements barycentric interpolation on a hypersphere. 25 | 26 | The approach: 27 | 1. Project points onto a tangent space at their weighted Euclidean mean. 28 | 2. Perform the interpolation in the tangent space. 29 | 3. Project the result back to the hypersphere. 30 | 31 | Limitations: 32 | - The weighted sum of the input tensors must not be zero. 33 | - The tensors must not be all parallel or antiparallel. 
34 | 35 | Args: 36 | tensors: List of tensors to interpolate 37 | weight: List of weights for each tensor 38 | base_tensor: Optional tensor defining the origin of the hypersphere 39 | normalize_weights: If True, the weights will be normalized to sum to 1 40 | eps: Small constant for numerical stability 41 | """ 42 | if len(tensors) == 1: 43 | # No interpolation needed 44 | return tensors[0] 45 | 46 | tensors = torch.stack(tensors, dim=0) 47 | if base_tensor is not None: 48 | tensors -= base_tensor 49 | 50 | tensors_flat = tensors.view(tensors.shape[0], -1) 51 | 52 | weights = torch.tensor(weight, dtype=tensors.dtype, device=tensors.device) 53 | if normalize_weights: 54 | weights = weights / weights.sum() 55 | 56 | # Project to unit hypersphere 57 | norms = torch.norm(tensors_flat, dim=-1, keepdim=True) 58 | unit_tensors = tensors_flat / (norms + eps) 59 | 60 | mean = (unit_tensors * weights.view(-1, 1)).sum(0) 61 | mean_norm = torch.norm(mean) 62 | if mean_norm < eps: 63 | if tensors.shape[0] == 2: 64 | # fallback to linear interpolation 65 | res = (tensors[0] * weights[0] + tensors[1] * weights[1]).view( 66 | tensors.shape[1:] 67 | ) 68 | if base_tensor is not None: 69 | res = res + base_tensor 70 | return res 71 | raise ValueError( 72 | "The weighted sum of the input tensors is zero. This occurs when " 73 | "antipodal vectors or sets of vectors have weights that exactly " 74 | "balance out (e.g., vectors a,-a with equal weights). Try using " 75 | "different weights if you have antipodal vectors." 76 | ) 77 | mean = mean / mean_norm 78 | 79 | # Project to tangent space 80 | dots = (unit_tensors * mean).sum(-1, keepdim=True) 81 | tangent_vectors = unit_tensors - dots * mean 82 | 83 | # Interpolate 84 | tangent_result = (tangent_vectors * weights.view(-1, 1)).sum(0) 85 | 86 | # Project back to sphere using exponential map 87 | tangent_norm = torch.norm(tangent_result) + eps 88 | result = mean * torch.cos(tangent_norm) + tangent_result * ( 89 | torch.sin(tangent_norm) / tangent_norm 90 | ) 91 | 92 | avg_norm = (norms.squeeze(-1) * weights).sum() 93 | result = result * avg_norm 94 | result = result.view(tensors.shape[1:]) 95 | 96 | if base_tensor is not None: 97 | result = result + base_tensor 98 | 99 | return result 100 | -------------------------------------------------------------------------------- /mergekit/merge_methods/nearswap.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import List 5 | 6 | import torch 7 | 8 | from mergekit.merge_methods.easy_define import merge_method 9 | 10 | 11 | @merge_method( 12 | name="nearswap", 13 | pretty_name="NearSwap", 14 | reference_url="https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001", 15 | ) 16 | def nearswap_merge( 17 | tensors: List[torch.Tensor], base_tensor: torch.Tensor, t: float 18 | ) -> torch.Tensor: 19 | if not tensors: 20 | return base_tensor 21 | if len(tensors) != 1: 22 | raise RuntimeError( 23 | "NearSwap merge expects exactly two models, one base and one other" 24 | ) 25 | a = base_tensor 26 | b = tensors[0] 27 | 28 | absdiff = torch.abs(a - b) 29 | weight = (t / absdiff.clamp(min=1e-6)).clamp(min=0, max=1) 30 | return weight * b + (1 - weight) * a 31 | -------------------------------------------------------------------------------- /mergekit/merge_methods/passthrough.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # 
SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import Any, Dict, List, Optional 5 | 6 | import torch 7 | from typing_extensions import override 8 | 9 | from mergekit.common import ImmutableMap, ModelReference 10 | from mergekit.graph import Task 11 | from mergekit.merge_methods.base import ( 12 | ConfigParameterDef, 13 | MergeMethod, 14 | MergeTensorInput, 15 | ) 16 | 17 | 18 | class PassthroughMergeTask(Task[torch.Tensor]): 19 | gather_tensors: MergeTensorInput 20 | tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]] 21 | 22 | def arguments(self) -> Dict[str, Task]: 23 | return {"tensors": self.gather_tensors} 24 | 25 | def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: 26 | if len(tensors) != 1: 27 | raise RuntimeError("Passthrough merge expects exactly one tensor") 28 | 29 | model, tensor = list(tensors.items())[0] 30 | scale = self.tensor_parameters[model].data.get("scale", None) 31 | if scale is not None: 32 | tensor = tensor * scale 33 | 34 | return tensor 35 | 36 | def group_label(self) -> Optional[str]: 37 | return self.gather_tensors.group_label() 38 | 39 | 40 | class PassthroughMerge(MergeMethod): 41 | def name(self) -> str: 42 | return "passthrough" 43 | 44 | @override 45 | def pretty_name(self) -> Optional[str]: 46 | return "Passthrough" 47 | 48 | def tensor_parameters(self) -> List[ConfigParameterDef]: 49 | return [ConfigParameterDef(name="scale", required=False, default_value=None)] 50 | 51 | def make_task( 52 | self, 53 | *, 54 | tensors: MergeTensorInput, 55 | tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], 56 | **kwargs, 57 | ) -> Task: 58 | return PassthroughMergeTask( 59 | gather_tensors=tensors, tensor_parameters=tensor_parameters 60 | ) 61 | -------------------------------------------------------------------------------- /mergekit/merge_methods/rectify_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | 5 | import logging 6 | from typing import List 7 | 8 | import torch 9 | 10 | from mergekit.architecture import WeightInfo 11 | 12 | 13 | def rectify_embed_sizes(weight_info: WeightInfo, tensors: List[torch.Tensor]): 14 | # TODO: use arch_info.embed_weights() instead 15 | if weight_info.is_embed and all(len(t.shape) == 2 for t in tensors): 16 | # special case - if lm_head.weight or embed_tokens.weight have a size 17 | # mismatch, take the largest common submatrix of all of them 18 | if take_common_submatrix(tensors): 19 | logging.warning( 20 | f"Using common submatrix of size {tensors[0].shape} for {weight_info.name}" 21 | ) 22 | 23 | 24 | def take_common_submatrix(tensors: List[torch.Tensor]) -> bool: 25 | min_size = [None, None] 26 | for t in tensors: 27 | for idx in range(2): 28 | if min_size[idx] is None or t.shape[idx] < min_size[idx]: 29 | min_size[idx] = t.shape[idx] 30 | 31 | if not all(t.shape == torch.Size(min_size) for t in tensors): 32 | for idx in range(len(tensors)): 33 | tensors[idx] = tensors[idx][: min_size[0], : min_size[1]] 34 | return True 35 | return False 36 | -------------------------------------------------------------------------------- /mergekit/merge_methods/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import Dict, List 5 | 6 | from mergekit.merge_methods.arcee_fusion import ArceeFusionMerge 7 | from 
mergekit.merge_methods.base import MergeMethod 8 | from mergekit.merge_methods.generalized_task_arithmetic import ( 9 | ConsensusMethod, 10 | GeneralizedTaskArithmeticMerge, 11 | ) 12 | from mergekit.merge_methods.karcher import KarcherMerge 13 | from mergekit.merge_methods.linear import LinearMerge 14 | from mergekit.merge_methods.model_stock import ModelStockMerge 15 | from mergekit.merge_methods.nuslerp import NuSlerpMerge 16 | from mergekit.merge_methods.passthrough import PassthroughMerge 17 | from mergekit.merge_methods.slerp import SlerpMerge 18 | from mergekit.sparsify import SparsificationMethod 19 | 20 | STATIC_MERGE_METHODS: List[MergeMethod] = [ 21 | LinearMerge(), 22 | SlerpMerge(), 23 | NuSlerpMerge(), 24 | PassthroughMerge(), 25 | ModelStockMerge(), 26 | ArceeFusionMerge(), 27 | KarcherMerge(), 28 | # generalized task arithmetic methods 29 | GeneralizedTaskArithmeticMerge( 30 | consensus_method=None, 31 | sparsification_method=None, 32 | default_normalize=False, 33 | default_rescale=False, 34 | method_name="task_arithmetic", 35 | method_pretty_name="Task Arithmetic", 36 | method_reference_url="https://arxiv.org/abs/2212.04089", 37 | ), 38 | GeneralizedTaskArithmeticMerge( 39 | consensus_method=ConsensusMethod.sum, 40 | sparsification_method=SparsificationMethod.magnitude, 41 | default_normalize=True, 42 | default_rescale=False, 43 | method_name="ties", 44 | method_pretty_name="TIES", 45 | method_reference_url="https://arxiv.org/abs/2306.01708", 46 | ), 47 | GeneralizedTaskArithmeticMerge( 48 | consensus_method=ConsensusMethod.sum, 49 | sparsification_method=SparsificationMethod.random, 50 | default_normalize=False, 51 | default_rescale=True, 52 | method_name="dare_ties", 53 | method_pretty_name="DARE TIES", 54 | method_reference_url="https://arxiv.org/abs/2311.03099", 55 | ), 56 | GeneralizedTaskArithmeticMerge( 57 | consensus_method=None, 58 | sparsification_method=SparsificationMethod.random, 59 | default_normalize=False, 60 | default_rescale=True, 61 | method_name="dare_linear", 62 | method_pretty_name="Linear DARE", 63 | method_reference_url="https://arxiv.org/abs/2311.03099", 64 | ), 65 | GeneralizedTaskArithmeticMerge( 66 | consensus_method=None, 67 | sparsification_method=SparsificationMethod.magnitude_outliers, 68 | default_normalize=False, 69 | default_rescale=False, 70 | method_name="breadcrumbs", 71 | method_pretty_name="Model Breadcrumbs", 72 | method_reference_url="https://arxiv.org/abs/2312.06795", 73 | ), 74 | GeneralizedTaskArithmeticMerge( 75 | consensus_method=ConsensusMethod.sum, 76 | sparsification_method=SparsificationMethod.magnitude_outliers, 77 | default_normalize=False, 78 | default_rescale=False, 79 | method_name="breadcrumbs_ties", 80 | method_pretty_name="Model Breadcrumbs with TIES", 81 | method_reference_url="https://arxiv.org/abs/2312.06795", 82 | ), 83 | GeneralizedTaskArithmeticMerge( 84 | consensus_method=ConsensusMethod.sum, 85 | sparsification_method=SparsificationMethod.della_magprune, 86 | default_normalize=True, 87 | default_rescale=True, 88 | method_name="della", 89 | method_pretty_name="DELLA", 90 | method_reference_url="https://arxiv.org/abs/2406.11617", 91 | ), 92 | GeneralizedTaskArithmeticMerge( 93 | consensus_method=None, 94 | sparsification_method=SparsificationMethod.della_magprune, 95 | default_normalize=False, 96 | default_rescale=True, 97 | method_name="della_linear", 98 | method_pretty_name="Linear DELLA", 99 | method_reference_url="https://arxiv.org/abs/2406.11617", 100 | ), 101 | ] 102 | 103 | REGISTERED_MERGE_METHODS: 
Dict[str, MergeMethod] = { 104 | method.name(): method for method in STATIC_MERGE_METHODS 105 | } 106 | -------------------------------------------------------------------------------- /mergekit/merge_methods/sce.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import List, Optional 5 | 6 | import torch 7 | 8 | from mergekit.merge_methods.easy_define import merge_method 9 | from mergekit.merge_methods.generalized_task_arithmetic import ( 10 | get_mask as sign_consensus_mask, 11 | ) 12 | 13 | 14 | @merge_method( 15 | name="sce", 16 | pretty_name="SCE", 17 | reference_url="https://arxiv.org/abs/2408.07990", 18 | ) 19 | def sce_merge( 20 | tensors: List[torch.Tensor], 21 | base_tensor: torch.Tensor, 22 | int8_mask: bool = False, 23 | select_topk: float = 1.0, 24 | ) -> torch.Tensor: 25 | if not tensors: 26 | return base_tensor 27 | mask_dtype = torch.int8 if int8_mask else base_tensor.dtype 28 | task_vectors = torch.stack([t - base_tensor for t in tensors], dim=0) 29 | 30 | if select_topk < 1: 31 | mask = sce_mask(task_vectors, select_topk, mask_dtype) 32 | task_vectors = task_vectors * mask.unsqueeze(0) 33 | 34 | erase_mask = sign_consensus_mask(task_vectors, method="sum", mask_dtype=mask_dtype) 35 | 36 | tv_weights = sce_weight(task_vectors) 37 | while tv_weights.dim() < task_vectors.dim(): 38 | tv_weights = tv_weights.unsqueeze(-1) 39 | 40 | erased_weights = tv_weights * erase_mask 41 | merged_tv = (task_vectors * erased_weights).sum(dim=0) 42 | final_tv = merged_tv / torch.sum(erased_weights, dim=0).clamp(min=1e-6) 43 | 44 | return base_tensor + final_tv 45 | 46 | 47 | def sce_weight(tvs: torch.Tensor) -> torch.Tensor: 48 | weights = torch.mean(tvs**2, dim=list(range(1, tvs.dim()))) 49 | weight_sum = torch.sum(weights).item() 50 | if abs(weight_sum) < 1e-6: 51 | return torch.ones_like(weights) / weights.shape[0] 52 | return weights / weight_sum 53 | 54 | 55 | def sce_mask( 56 | tvs: torch.Tensor, density: float, mask_dtype: Optional[torch.dtype] = None 57 | ): 58 | if density <= 0: 59 | return torch.zeros_like(tvs, dtype=mask_dtype) 60 | if density >= 1: 61 | return torch.ones_like(tvs, dtype=mask_dtype) 62 | 63 | var = torch.var(tvs, dim=0, unbiased=False) 64 | nonzero = torch.count_nonzero(var) 65 | k = int(nonzero * density) 66 | if k == 0: 67 | return torch.zeros_like(tvs, dtype=mask_dtype) 68 | 69 | _, indices = torch.topk(var.abs().view(-1), k=k, largest=True) 70 | mask = torch.zeros_like(var, dtype=mask_dtype) 71 | mask.view(-1)[indices] = 1 72 | return mask 73 | -------------------------------------------------------------------------------- /mergekit/moe/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from mergekit.moe.arch import MoEOutputArchitecture 4 | from mergekit.moe.deepseek import DeepseekMoE 5 | from mergekit.moe.mixtral import MixtralMoE 6 | 7 | ALL_OUTPUT_ARCHITECTURES: List[MoEOutputArchitecture] = [MixtralMoE(), DeepseekMoE()] 8 | 9 | try: 10 | from mergekit.moe.qwen import QwenMoE 11 | except ImportError: 12 | pass 13 | else: 14 | ALL_OUTPUT_ARCHITECTURES.append(QwenMoE()) 15 | 16 | __all__ = [ 17 | "ALL_OUTPUT_ARCHITECTURES", 18 | "MoEOutputArchitecture", 19 | ] 20 | -------------------------------------------------------------------------------- /mergekit/moe/arch.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import List, Optional 6 | 7 | import torch 8 | 9 | from mergekit.moe.config import MoEMergeConfig 10 | from mergekit.options import MergeOptions 11 | 12 | 13 | class MoEOutputArchitecture(ABC): 14 | @abstractmethod 15 | def name(self) -> str: 16 | """Return a human-readable name for the architecture.""" 17 | pass 18 | 19 | @abstractmethod 20 | def supports_config( 21 | self, 22 | config: MoEMergeConfig, 23 | explain: bool = False, 24 | trust_remote_code: bool = False, 25 | ) -> bool: 26 | """Return whether this architecture supports the given config. 27 | 28 | If `explain` is True, log an explanation of why the config is not supported.""" 29 | pass 30 | 31 | @abstractmethod 32 | def write_model( 33 | self, 34 | out_path: str, 35 | config: MoEMergeConfig, 36 | merge_options: MergeOptions, 37 | router_weights: List[torch.Tensor], 38 | shared_router_weights: Optional[List[torch.Tensor]] = None, 39 | ): 40 | """Write the config and tensors for the output MoE to the given path.""" 41 | pass 42 | -------------------------------------------------------------------------------- /mergekit/moe/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | from typing import Dict, Optional, Tuple 6 | 7 | import torch 8 | import tqdm 9 | import transformers 10 | 11 | from mergekit.architecture import WeightInfo 12 | from mergekit.common import ModelReference, dtype_from_name 13 | from mergekit.io import LazyTensorLoader, TensorWriter 14 | from mergekit.merge import MergeOptions 15 | from mergekit.moe.config import Expert, MoEMergeConfig 16 | 17 | 18 | def initialize_io( 19 | config: MoEMergeConfig, 20 | out_path: str, 21 | merge_options: MergeOptions, 22 | ) -> Tuple[Dict[ModelReference, LazyTensorLoader], LazyTensorLoader, TensorWriter]: 23 | base_model = config.base_model 24 | loaders: Dict[ModelReference, LazyTensorLoader] = {} 25 | for model in tqdm.tqdm( 26 | [base_model] + [e.source_model for e in config.experts], desc="Warm up loaders" 27 | ): 28 | loaders[model] = model.lazy_loader( 29 | cache_dir=merge_options.transformers_cache, 30 | lazy_unpickle=merge_options.lazy_unpickle, 31 | ) 32 | 33 | base_loader = loaders.get(base_model) 34 | writer = TensorWriter( 35 | out_path=out_path, 36 | max_shard_size=merge_options.out_shard_size, 37 | safe_serialization=merge_options.safe_serialization, 38 | ) 39 | 40 | return loaders, base_loader, writer 41 | 42 | 43 | def select_dtype( 44 | config: MoEMergeConfig, base_cfg: transformers.PretrainedConfig 45 | ) -> Optional[torch.dtype]: 46 | out_dtype = None 47 | if config.dtype: 48 | out_dtype = dtype_from_name(config.dtype) 49 | 50 | if out_dtype is None and base_cfg.torch_dtype: 51 | out_dtype = base_cfg.torch_dtype 52 | if isinstance(out_dtype, str): 53 | out_dtype = dtype_from_name(out_dtype) 54 | return out_dtype 55 | 56 | 57 | def noise_and_scale( 58 | tensor: torch.Tensor, expert: Expert, is_residual: bool = False 59 | ) -> torch.Tensor: 60 | if expert.noise_scale is not None: 61 | noise = torch.randn_like(tensor) * expert.noise_scale 62 | tensor = tensor + noise 63 | if is_residual and expert.residual_scale is not None: 64 | tensor = tensor * expert.residual_scale 65 | return tensor 66 | 67 | 68 | def copy_tensor_out( 69 | weight_info: WeightInfo, 70 | loader: LazyTensorLoader, 71 | writer: 
TensorWriter, 72 | expert: Optional[Expert] = None, 73 | is_residual: bool = False, 74 | output_name: Optional[str] = None, 75 | out_dtype: Optional[torch.dtype] = None, 76 | clone: bool = False, 77 | ): 78 | out_tensor_name = output_name or weight_info.name 79 | try: 80 | tensor = loader.get_tensor(weight_info.name, aliases=weight_info.aliases) 81 | except KeyError: 82 | tensor = None 83 | if tensor is None and not weight_info.optional: 84 | logging.error(f"Missing weight: {weight_info.name} / {out_tensor_name}") 85 | raise KeyError(out_tensor_name) 86 | 87 | if expert: 88 | tensor = noise_and_scale(tensor, expert, is_residual=is_residual) 89 | writer.save_tensor( 90 | out_tensor_name, 91 | tensor.to(dtype=out_dtype), 92 | clone=clone, 93 | ) 94 | -------------------------------------------------------------------------------- /mergekit/moe/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import logging 5 | from typing import List, Optional 6 | 7 | from pydantic import BaseModel 8 | 9 | from mergekit.common import ModelReference 10 | 11 | 12 | class Expert(BaseModel): 13 | """ 14 | Defines a model to be used as a set of layerwise experts in a MoE model. 15 | """ 16 | 17 | source_model: ModelReference 18 | 19 | positive_prompts: Optional[List[str]] = None 20 | negative_prompts: Optional[List[str]] = None 21 | noise_scale: Optional[float] = None 22 | residual_scale: Optional[float] = None 23 | 24 | 25 | class MoEMergeConfig(BaseModel): 26 | """ 27 | Configuration for merging a set of "expert" models into a MoE model. 28 | """ 29 | 30 | base_model: ModelReference 31 | experts: List[Expert] 32 | gate_mode: str = ( 33 | "hidden" # possible values: "hidden", "cheap_embed", "random", "uniform_random" 34 | ) 35 | # "hidden" uses hidden state vectors for the given prompts for each layer 36 | # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer 37 | # "random" is random 38 | # "uniform_random" matches default initialization for torch.nn.Linear 39 | dtype: Optional[str] = None 40 | experts_per_token: int = 2 41 | shared_experts: Optional[List[Expert]] = None 42 | architecture: Optional[str] = None 43 | 44 | 45 | def is_bad_config(config: MoEMergeConfig, allow_all_same: bool = False) -> bool: 46 | if config.experts_per_token < 1: 47 | logging.error("Experts per token must be >= 1") 48 | return True 49 | 50 | if len(config.experts) < config.experts_per_token: 51 | logging.error("Must include at least as many experts as experts_per_token.") 52 | return True 53 | 54 | if config.gate_mode == "random": 55 | return False # eh we're good 56 | 57 | for expert_idx, expert in enumerate(config.experts): 58 | if not expert.positive_prompts: 59 | logging.error(f"Expert {expert_idx} has no positive prompts.") 60 | return True 61 | 62 | def prompt_tup(e: Expert): 63 | return (tuple(e.positive_prompts), tuple(e.negative_prompts or [])) 64 | 65 | # let's just nip this trend in the bud 66 | p_first = prompt_tup(config.experts[0]) 67 | if all(prompt_tup(e) == p_first for e in config.experts[1:]): 68 | logging.error( 69 | "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE." 70 | ) 71 | logging.error( 72 | "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert." 
73 | ) 74 | return True 75 | 76 | if not allow_all_same: 77 | if all( 78 | e.source_model == config.experts[0].source_model for e in config.experts[1:] 79 | ): 80 | logging.error( 81 | "All of your expert models are the same. This will produce " 82 | "a model that uses more resources but gives the exact same output. " 83 | "If you plan to train the model after merging, proceed with the " 84 | "--i-understand-this-is-not-useful-without-training flag." 85 | ) 86 | return True 87 | -------------------------------------------------------------------------------- /mergekit/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/scripts/__init__.py -------------------------------------------------------------------------------- /mergekit/scripts/bakllama.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | from typing import List, Optional 5 | 6 | import click 7 | import yaml 8 | from pydantic import BaseModel 9 | 10 | from mergekit.common import MergeOptions 11 | from mergekit.config import ( 12 | ConditionalParameter, 13 | InputSliceDefinition, 14 | MergeConfiguration, 15 | ) 16 | from mergekit.merge import run_merge 17 | 18 | 19 | class LayerSlice(BaseModel): 20 | model: str 21 | start: int 22 | end: int 23 | scale: Optional[float] = None 24 | 25 | 26 | class BakllamaConfig(BaseModel): 27 | layer_slices: List[LayerSlice] 28 | embedding_source: Optional[str] = None 29 | lm_head_source: Optional[str] = None 30 | 31 | 32 | @click.command("bakllama") 33 | @click.argument("config_path", type=click.Path(exists=True, dir_okay=False)) 34 | @click.argument("out_path", type=str) 35 | @click.option( 36 | "--clone-tensors/--no-clone-tensors", 37 | type=bool, 38 | is_flag=True, 39 | help="Clone tensors before saving, to allow multiple occurrences of the same layer", 40 | default=False, 41 | ) 42 | @click.option("--fp16/--no-fp16", type=bool, default=False) 43 | def main( 44 | config_path: str, 45 | out_path: str, 46 | clone_tensors: bool, 47 | fp16: bool, 48 | ): 49 | """Wrapper for using legacy bakllama configuration files.""" 50 | with open(config_path, "r", encoding="utf-8") as file: 51 | config = BakllamaConfig.model_validate(yaml.safe_load(file)) 52 | 53 | slices = [] 54 | for s in config.layer_slices: 55 | parameters = {} 56 | if s.scale is not None: 57 | parameters["scale"] = ConditionalParameter( 58 | value=s.scale, filter="down_proj" 59 | ) 60 | slices.append( 61 | InputSliceDefinition( 62 | model=s.model, layer_range=(s.start, s.end), parameters=parameters 63 | ) 64 | ) 65 | 66 | merge_config = MergeConfiguration( 67 | merge_method="passthrough", slices=slices, dtype="float16" if fp16 else None 68 | ) 69 | run_merge(merge_config, out_path, MergeOptions(clone_tensors=clone_tensors)) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /mergekit/scripts/layershuffle.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | import random 5 | from typing import List 6 | 7 | import click 8 | import yaml 9 | 10 | from mergekit.architecture import arch_info_for_config 11 | from mergekit.common import ModelReference 12 | from mergekit.config import ( 13 
| InputSliceDefinition, 14 | MergeConfiguration, 15 | OutputSliceDefinition, 16 | ) 17 | from mergekit.merge import run_merge 18 | from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options 19 | 20 | 21 | @click.command("mergekit-layershuffle", cls=PrettyPrintHelp) 22 | @click.argument("out_path", type=str) 23 | @click.option("--model", "-m", multiple=True, type=str, help="Add a model to the merge") 24 | @click.option( 25 | "--weight", 26 | "-w", 27 | multiple=True, 28 | type=float, 29 | default=[], 30 | show_default=False, 31 | help="Weighting for a model", 32 | ) 33 | @click.option( 34 | "--print-yaml/--no-print-yaml", 35 | is_flag=True, 36 | help="Print YAML merge config for resulting model", 37 | ) 38 | @click.option( 39 | "--write-yaml", 40 | type=click.Path(writable=True), 41 | help="Path to write YAML merge config to", 42 | ) 43 | @click.option( 44 | "--dry-run", is_flag=True, help="Generate a config but do not run the merge" 45 | ) 46 | @click.option("--fp16/--no-fp16", is_flag=True, help="Use FP16 precision") 47 | @click.option( 48 | "--full-random/--no-full-random", 49 | is_flag=True, 50 | help="Randomize layer index as well as source model", 51 | ) 52 | @add_merge_options 53 | def main( 54 | out_path: str, 55 | model: List[str], 56 | weight: List[float], 57 | print_yaml: bool, 58 | write_yaml: bool, 59 | dry_run: bool, 60 | fp16: bool, 61 | full_random: bool, 62 | merge_options: MergeOptions, 63 | ): 64 | models = [ModelReference.parse(m) for m in model] 65 | 66 | m0_cfg = models[0].config() 67 | arch_info = arch_info_for_config(m0_cfg) 68 | total_num_layers = arch_info.num_layers(m0_cfg) 69 | 70 | out_slices: List[OutputSliceDefinition] = [] 71 | 72 | if full_random: 73 | for model, frac in zip(models, weight): 74 | cfg = model.config() 75 | num_layers = int(arch_info.num_layers(cfg) * frac) 76 | for _ in range(num_layers): 77 | src_idx = random.randrange(0, num_layers) 78 | out_slices.append( 79 | OutputSliceDefinition( 80 | sources=[ 81 | InputSliceDefinition( 82 | model=str(model), 83 | layer_range=(src_idx, src_idx + 1), 84 | ) 85 | ] 86 | ) 87 | ) 88 | random.shuffle(out_slices) 89 | else: 90 | for layer_idx in range(total_num_layers): 91 | src_model = random.choices(models, weights=weight, k=1)[0] 92 | if out_slices and out_slices[-1].sources[0].model == str(src_model): 93 | out_slices[-1].sources[0].layer_range = ( 94 | out_slices[-1].sources[0].layer_range[0], 95 | layer_idx + 1, 96 | ) 97 | else: 98 | out_slices.append( 99 | OutputSliceDefinition( 100 | sources=[ 101 | InputSliceDefinition( 102 | model=str(src_model), 103 | layer_range=(layer_idx, layer_idx + 1), 104 | ) 105 | ] 106 | ) 107 | ) 108 | merge_config = MergeConfiguration( 109 | merge_method="passthrough", slices=out_slices, dtype="float16" if fp16 else None 110 | ) 111 | 112 | if print_yaml or write_yaml: 113 | yaml_str = yaml.dump(merge_config.model_dump(exclude_none=True, mode="json")) 114 | 115 | if print_yaml: 116 | print(yaml_str) 117 | if write_yaml: 118 | with open(write_yaml, "w", encoding="utf-8") as file: 119 | file.write(yaml_str) 120 | 121 | if dry_run: 122 | return 123 | 124 | run_merge( 125 | merge_config, 126 | out_path, 127 | options=merge_options, 128 | ) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /mergekit/scripts/legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Arcee AI 2 | # SPDX-License-Identifier: 
BUSL-1.1 3 | 4 | from typing import List, Optional 5 | 6 | import click 7 | import yaml 8 | 9 | from mergekit.config import InputModelDefinition, MergeConfiguration 10 | from mergekit.merge import run_merge 11 | from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options 12 | 13 | 14 | @click.command("mergekit-legacy", cls=PrettyPrintHelp) 15 | @click.argument("out_path", type=str) 16 | @click.option( 17 | "--merge", "merge", type=str, multiple=True, help="Add a model to the merge" 18 | ) 19 | @click.option( 20 | "--density", 21 | "density", 22 | type=float, 23 | multiple=True, 24 | default=[], 25 | help="Fraction of weights to keep for each model (ties only)", 26 | ) 27 | @click.option( 28 | "--weight", 29 | "weight", 30 | type=float, 31 | multiple=True, 32 | default=[], 33 | help="Weighting for a model (default 1.0 for all models if not specified)", 34 | ) 35 | @click.option( 36 | "--method", "method", type=str, default="ties", help="Method used to merge models" 37 | ) 38 | @click.option( 39 | "--base-model", "base_model", type=str, default=None, help="Base model for merge" 40 | ) 41 | @click.option( 42 | "--normalize/--no-normalize", 43 | "normalize", 44 | is_flag=True, 45 | default=True, 46 | help="Divide merged parameters by the sum of weights", 47 | ) 48 | @click.option( 49 | "--int8-mask/--no-int8-mask", 50 | "int8_mask", 51 | is_flag=True, 52 | help="Store intermediate masks in int8 to save memory", 53 | ) 54 | @click.option("--bf16/--no-bf16", "bf16", is_flag=True, help="Use bfloat16") 55 | @click.option( 56 | "--naive-count/--no-naive-count", 57 | "naive_count", 58 | is_flag=True, 59 | help="Use naive sign count instead of weight (ties only)", 60 | ) 61 | @click.option( 62 | "--print-yaml/--no-print-yaml", 63 | "print_yaml", 64 | is_flag=True, 65 | help="Print generated YAML configuration", 66 | ) 67 | @add_merge_options 68 | def main( 69 | out_path: str, 70 | merge: List[str], 71 | density: List[float], 72 | weight: List[float], 73 | method: str, 74 | base_model: Optional[str], 75 | normalize: bool, 76 | int8_mask: bool, 77 | bf16: bool, 78 | naive_count: bool, 79 | print_yaml: bool, 80 | merge_options: MergeOptions, 81 | ): 82 | """Wrapper for using a subset of legacy-style script arguments.""" 83 | models = [InputModelDefinition(model=model, parameters={}) for model in merge] 84 | if base_model and base_model not in merge: 85 | models.append(InputModelDefinition(model=base_model, parameters={})) 86 | 87 | parameters = {} 88 | 89 | if density: 90 | if len(density) == 1: 91 | density = [density[0]] * len(models) 92 | for idx, d in enumerate(density): 93 | models[idx].parameters["density"] = d 94 | 95 | if method == "slerp": 96 | assert len(weight) == 1, "Must specify exactly one weight for SLERP" 97 | parameters["t"] = weight[0] 98 | else: 99 | if weight: 100 | if len(weight) == 1: 101 | weight = [weight[0]] * len(models) 102 | for idx, w in enumerate(weight): 103 | models[idx].parameters["weight"] = w 104 | 105 | if int8_mask: 106 | parameters["int8_mask"] = True 107 | if naive_count: 108 | parameters["consensus_method"] = "count" 109 | parameters["normalize"] = normalize 110 | 111 | merge_config = MergeConfiguration( 112 | merge_method=method, 113 | models=models, 114 | parameters=parameters, 115 | base_model=base_model, 116 | dtype="bfloat16" if bf16 else None, 117 | ) 118 | 119 | if print_yaml: 120 | print(yaml.dump(merge_config.model_dump(mode="json", exclude_none=True))) 121 | 122 | run_merge( 123 | merge_config, 124 | out_path, 125 | 
--------------------------------------------------------------------------------
/mergekit/scripts/run_yaml.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Arcee AI
2 | # SPDX-License-Identifier: BUSL-1.1
3 |
4 |
5 | import click
6 | import yaml
7 |
8 | from mergekit.config import MergeConfiguration
9 | from mergekit.merge import run_merge
10 | from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
11 |
12 |
13 | @click.command("mergekit-yaml", cls=PrettyPrintHelp)
14 | @click.argument("config_file")
15 | @click.argument("out_path")
16 | @add_merge_options
17 | def main(
18 |     merge_options: MergeOptions,
19 |     config_file: str,
20 |     out_path: str,
21 | ):
22 |     merge_options.apply_global_options()
23 |
24 |     with open(config_file, "r", encoding="utf-8") as file:
25 |         config_source = file.read()
26 |
27 |     merge_config: MergeConfiguration = MergeConfiguration.model_validate(
28 |         yaml.safe_load(config_source)
29 |     )
30 |     run_merge(
31 |         merge_config,
32 |         out_path,
33 |         options=merge_options,
34 |         config_source=config_source,
35 |     )
36 |
37 |
38 | if __name__ == "__main__":
39 |     main()
40 |
--------------------------------------------------------------------------------
/mergekit/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Arcee AI
2 | # SPDX-License-Identifier: BUSL-1.1
3 |
4 | from mergekit.tokenizer.build import BuildTokenizer, TokenizerInfo
5 | from mergekit.tokenizer.config import TokenizerConfig
6 | from mergekit.tokenizer.embed import PermutedEmbeddings
7 |
8 | __all__ = ["BuildTokenizer", "TokenizerInfo", "TokenizerConfig", "PermutedEmbeddings"]
9 |
--------------------------------------------------------------------------------
/mergekit/tokenizer/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Arcee AI
2 | # SPDX-License-Identifier: BUSL-1.1
3 |
4 | from typing import Dict, Optional, Union
5 |
6 | import pydantic
7 | from pydantic import BaseModel
8 | from typing_extensions import Literal
9 |
10 | from mergekit.common import ModelReference
11 |
12 |
13 | class ModelTokenEmbedding(BaseModel, frozen=True):
14 |     kind: Literal["model_token"]
15 |     model: ModelReference
16 |     token_id: Optional[int] = None
17 |     token: Optional[str] = None
18 |
19 |     @pydantic.model_validator(mode="after")
20 |     def validate_token(self):
21 |         if self.token_id is None and self.token is None:
22 |             raise ValueError("token_id or token must be specified")
23 |         if self.token_id is not None and self.token is not None:
24 |             raise ValueError("only one of token_id or token may be specified")
25 |         return self
26 |
27 |
28 | class ZeroEmbedding(BaseModel, frozen=True):
29 |     kind: Literal["zero"]
30 |
31 |
32 | class TokenEmbeddingConfig(BaseModel, frozen=True):
33 |     source: Union[ModelTokenEmbedding, ZeroEmbedding, ModelReference, None] = None
34 |     force: bool = False
35 |
36 |
37 | class TokenizerConfig(BaseModel, frozen=True):
38 |     source: Union[ModelReference, Literal["union"], Literal["base"]] = "union"
39 |     tokens: Optional[Dict[str, TokenEmbeddingConfig]] = None
40 |     pad_to_multiple_of: Optional[int] = None
41 |
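A minimal sketch of what these models express in a merge config, assuming the YAML keys mirror the pydantic field names above and that MergeConfiguration accepts a tokenizer section (the model reference and token names are placeholders):

    tokenizer:
      source: union
      pad_to_multiple_of: 8
      tokens:
        <|im_start|>:
          source:
            kind: model_token
            model: some-org/chat-model
            token: <|im_start|>
          force: true
        <pad>:
          source:
            kind: zero

Per the validator in ModelTokenEmbedding, a model_token source must name exactly one of token or token_id.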
--------------------------------------------------------------------------------
/notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {
7 |     "collapsed": true,
8 |     "id": "cmjOVVtJdiPZ"
9 |    },
10 |    "outputs": [],
11 |    "source": [
12 |     "!git clone https://github.com/cg123/mergekit.git\n",
13 |     "%cd mergekit\n",
14 |     "%pip install -e ."
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "code",
19 |    "execution_count": 2,
20 |    "metadata": {
21 |     "id": "84cRJT6_ecbw"
22 |    },
23 |    "outputs": [],
24 |    "source": [
25 |     "OUTPUT_PATH = \"./merged\"  # folder to store the result in\n",
26 |     "LORA_MERGE_CACHE = \"/tmp\"  # change if you want to keep these for some reason\n",
27 |     "CONFIG_YML = \"./examples/gradient-slerp.yml\"  # merge configuration file\n",
28 |     "COPY_TOKENIZER = True  # you want a tokenizer? yeah, that's what i thought\n",
29 |     "LAZY_UNPICKLE = False  # experimental low-memory model loader\n",
30 |     "LOW_CPU_MEMORY = False  # enable if you somehow have more VRAM than RAM+swap"
31 |    ]
32 |   },
33 |   {
34 |    "cell_type": "code",
35 |    "execution_count": null,
36 |    "metadata": {
37 |     "id": "6nw26xQLkBax"
38 |    },
39 |    "outputs": [],
40 |    "source": [
41 |     "# actually do merge\n",
42 |     "import torch\n",
43 |     "import yaml\n",
44 |     "\n",
45 |     "from mergekit.config import MergeConfiguration\n",
46 |     "from mergekit.merge import MergeOptions, run_merge\n",
47 |     "\n",
48 |     "with open(CONFIG_YML, \"r\", encoding=\"utf-8\") as fp:\n",
49 |     "    merge_config = MergeConfiguration.model_validate(yaml.safe_load(fp))\n",
50 |     "\n",
51 |     "run_merge(\n",
52 |     "    merge_config,\n",
53 |     "    out_path=OUTPUT_PATH,\n",
54 |     "    options=MergeOptions(\n",
55 |     "        lora_merge_cache=LORA_MERGE_CACHE,\n",
56 |     "        cuda=torch.cuda.is_available(),\n",
57 |     "        copy_tokenizer=COPY_TOKENIZER,\n",
58 |     "        lazy_unpickle=LAZY_UNPICKLE,\n",
59 |     "        low_cpu_memory=LOW_CPU_MEMORY,\n",
60 |     "    ),\n",
61 |     ")\n",
62 |     "print(\"Done!\")"
63 |    ]
64 |   }
65 |  ],
66 |  "metadata": {
67 |   "accelerator": "GPU",
68 |   "colab": {
69 |    "gpuType": "T4",
70 |    "provenance": []
71 |   },
72 |   "kernelspec": {
73 |    "display_name": "Python 3",
74 |    "name": "python3"
75 |   },
76 |   "language_info": {
77 |    "codemirror_mode": {
78 |     "name": "ipython",
79 |     "version": 3
80 |    },
81 |    "file_extension": ".py",
82 |    "mimetype": "text/x-python",
83 |    "name": "python",
84 |    "nbconvert_exporter": "python",
85 |    "pygments_lexer": "ipython3",
86 |    "version": "3.10.13"
87 |   }
88 |  },
89 |  "nbformat": 4,
90 |  "nbformat_minor": 0
91 | }
92 |
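The notebook drives the merge through the Python API; the same merge can be run via the mergekit-yaml entry point registered under [project.scripts] in pyproject.toml below, assuming the CLI flags mirror the MergeOptions fields used in the notebook:

    mergekit-yaml ./examples/gradient-slerp.yml ./merged \
        --cuda --lazy-unpickle --copy-tokenizer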
["ray", "cma", "lm_eval", "wandb"] 36 | vllm = ["vllm==0.7.2", "lm_eval[vllm]"] 37 | 38 | [project.urls] 39 | repository = "https://github.com/cg123/mergekit" 40 | 41 | 42 | [project.scripts] 43 | mergekit-yaml = "mergekit.scripts.run_yaml:main" 44 | mergekit-legacy = "mergekit.scripts.legacy:main" 45 | mergekit-layershuffle = "mergekit.scripts.layershuffle:main" 46 | bakllama = "mergekit.scripts.bakllama:main" 47 | mergekit-moe = "mergekit.scripts.moe:main" 48 | mergekit-tokensurgeon = "mergekit.scripts.tokensurgeon:main" 49 | mergekit-extract-lora = "mergekit.scripts.extract_lora:main" 50 | mergekit-evolve = "mergekit.scripts.evolve:main" 51 | mergekit-pytorch = "mergekit.scripts.merge_raw_pytorch:main" 52 | mergekit-multi = "mergekit.scripts.multimerge:main" 53 | 54 | [tool.setuptools] 55 | packages = [ 56 | "mergekit", 57 | "mergekit.io", 58 | "mergekit.merge_methods", 59 | "mergekit.moe", 60 | "mergekit.scripts", 61 | "mergekit.evo", 62 | "mergekit.tokenizer", 63 | "mergekit.architecture", 64 | "mergekit._data", 65 | "mergekit._data.architectures", 66 | "mergekit._data.chat_templates", 67 | ] 68 | include-package-data = true 69 | package-data = { "mergekit._data.architectures" = [ 70 | "*.json", 71 | ], "mergekit._data.chat_templates" = [ 72 | "*.jinja", 73 | ] } 74 | 75 | [tool.isort] 76 | profile = "black" 77 | 78 | [tool.black] 79 | line-length = 88 80 | target-version = ['py37'] 81 | include = '\.pyi?$' 82 | 83 | [tool.pytest.ini_options] 84 | minversion = "6.0" 85 | filterwarnings = [ 86 | "ignore::pydantic.PydanticDeprecatedSince20:huggingface_hub.*:", 87 | "ignore::FutureWarning:huggingface_hub.*:", 88 | "ignore:Attempting Automatic Merge:UserWarning", 89 | "ignore:No architecture config available:UserWarning", 90 | ] 91 | testpaths = ["tests"] 92 | 93 | [dependency-groups] 94 | test = ["pytest~=8.3.5"] 95 | dev = ["black~=24.10.0", "isort~=5.13.2", "pre-commit~=4.1.0"] 96 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/tests/__init__.py -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Callable, Optional 4 | 5 | from transformers import ( 6 | AutoConfig, 7 | CLIPVisionConfig, 8 | GPT2Config, 9 | GPT2LMHeadModel, 10 | LlamaConfig, 11 | LlamaForCausalLM, 12 | LlavaConfig, 13 | LlavaForConditionalGeneration, 14 | ) 15 | 16 | from mergekit.architecture import ( 17 | arch_info_for_config, 18 | get_architecture_info, 19 | ) 20 | from mergekit.config import MergeConfiguration 21 | from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex 22 | from mergekit.merge import MergeOptions, run_merge 23 | 24 | 25 | def run_and_check_merge( 26 | config: MergeConfiguration, 27 | check_nan: bool = True, 28 | check_tensors: bool = True, 29 | validate: Optional[Callable[[str], None]] = None, 30 | index_json_name: Optional[str] = None, 31 | auto_arch: bool = False, 32 | ): 33 | if index_json_name is None: 34 | index_json_name = "model.safetensors.index.json" 35 | 36 | with tempfile.TemporaryDirectory() as tmpdir: 37 | run_merge(config, out_path=tmpdir, options=MergeOptions()) 38 | index_path = os.path.join(tmpdir, index_json_name) 39 | index_exists = 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arcee-ai/mergekit/488957e8e67c82861ecf63ef761f6bc59122dc74/tests/__init__.py
--------------------------------------------------------------------------------
/tests/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from typing import Callable, Optional
4 |
5 | from transformers import (
6 |     AutoConfig,
7 |     CLIPVisionConfig,
8 |     GPT2Config,
9 |     GPT2LMHeadModel,
10 |     LlamaConfig,
11 |     LlamaForCausalLM,
12 |     LlavaConfig,
13 |     LlavaForConditionalGeneration,
14 | )
15 |
16 | from mergekit.architecture import (
17 |     arch_info_for_config,
18 |     get_architecture_info,
19 | )
20 | from mergekit.config import MergeConfiguration
21 | from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
22 | from mergekit.merge import MergeOptions, run_merge
23 |
24 |
25 | def run_and_check_merge(
26 |     config: MergeConfiguration,
27 |     check_nan: bool = True,
28 |     check_tensors: bool = True,
29 |     validate: Optional[Callable[[str], None]] = None,
30 |     index_json_name: Optional[str] = None,
31 |     auto_arch: bool = False,
32 | ):
33 |     if index_json_name is None:
34 |         index_json_name = "model.safetensors.index.json"
35 |
36 |     with tempfile.TemporaryDirectory() as tmpdir:
37 |         run_merge(config, out_path=tmpdir, options=MergeOptions())
38 |         index_path = os.path.join(tmpdir, index_json_name)
39 |         index_exists = os.path.exists(index_path)
40 |         single_shard_exists = os.path.exists(index_path.replace(".index.json", ""))
41 |         assert index_exists or single_shard_exists, "No model produced by merge"
42 |         assert os.path.exists(
43 |             os.path.join(tmpdir, "config.json")
44 |         ), "No config json produced by merge"
45 |
46 |         if check_nan:
47 |             # check for NaN in output
48 |             loader = LazyTensorLoader.from_disk(tmpdir, lazy_unpickle=False)
49 |             tp = loader.index.tensor_paths
50 |             sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
51 |             for tensor_name in sorted_tensors:
52 |                 tensor = loader.get_tensor(tensor_name)
53 |                 has_nan = tensor.view(-1).isnan().any()
54 |                 assert not has_nan, "Output contains NaN"
55 |
56 |         if check_tensors:
57 |             model_config = AutoConfig.from_pretrained(tmpdir)
58 |             if auto_arch:
59 |                 arch_info = get_architecture_info(config, MergeOptions())
60 |             else:
61 |                 arch_info = arch_info_for_config(model_config)
62 |
63 |             index = ShardedTensorIndex.from_disk(tmpdir)
64 |             for weight_info in arch_info.all_weights(model_config):
65 |                 if weight_info.optional:
66 |                     continue
67 |                 if weight_info.name not in index.tensor_paths and not any(
68 |                     a in index.tensor_paths for a in weight_info.aliases
69 |                 ):
70 |                     raise RuntimeError(f"Output missing tensor {weight_info.name}")
71 |
72 |         if validate:
73 |             validate(tmpdir)
74 |
75 |
76 | def make_picollama(path: str, vocab_size: int = 64):
77 |     cfg = LlamaConfig(
78 |         vocab_size=vocab_size,
79 |         hidden_size=32,
80 |         intermediate_size=48,
81 |         num_attention_heads=16,
82 |         num_hidden_layers=2,
83 |     )
84 |     model = LlamaForCausalLM(cfg)
85 |     model.save_pretrained(path, safe_serialization=True)
86 |     return str(path)
87 |
88 |
89 | def make_gpt2size(path: str):
90 |     cfg = GPT2Config(
91 |         n_ctx=1024,
92 |         n_embd=768,
93 |         n_head=12,
94 |         n_layer=12,
95 |         n_positions=1024,
96 |         vocab_size=50257,
97 |     )
98 |     model = GPT2LMHeadModel(cfg)
99 |     model.save_pretrained(path, safe_serialization=True)
100 |     return str(path)
101 |
102 |
103 | def make_picoLlaVa(path: str):
104 |     # Define minimal vision configuration
105 |     vision_config = CLIPVisionConfig(
106 |         image_size=32,
107 |         patch_size=4,
108 |         num_hidden_layers=2,
109 |         num_attention_heads=2,
110 |         hidden_size=64,
111 |         intermediate_size=128,
112 |     )
113 |
114 |     # Define minimal text configuration
115 |     text_config = LlamaConfig(
116 |         vocab_size=64,
117 |         hidden_size=32,
118 |         intermediate_size=48,
119 |         num_attention_heads=16,
120 |         num_hidden_layers=2,
121 |     )
122 |
123 |     # Combine into Llava configuration
124 |     llava_config = LlavaConfig(
125 |         vision_config=vision_config,
126 |         text_config=text_config,
127 |         image_seq_length=16,
128 |     )
129 |
130 |     # Instantiate the model
131 |     model = LlavaForConditionalGeneration(config=llava_config)
132 |     model.save_pretrained(path, safe_serialization=True)
133 |     return str(path)
134 |
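A minimal sketch of how these helpers compose in a test; the merge settings are illustrative, mirroring the patterns in tests/test_chat_template.py below:

    import pytest

    from mergekit.config import InputModelDefinition, MergeConfiguration
    from tests.common import make_picollama, run_and_check_merge


    @pytest.fixture(scope="session")
    def tiny_model(tmp_path_factory):
        # build a throwaway pico-scale llama checkpoint
        return make_picollama(tmp_path_factory.mktemp("tiny_model"))


    def test_self_merge(tiny_model):
        config = MergeConfiguration(
            merge_method="linear",
            models=[InputModelDefinition(model=tiny_model, parameters={"weight": 1.0})],
            dtype="float32",
        )
        # merges into a temp dir, then checks for NaNs and missing tensors
        run_and_check_merge(config)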
--------------------------------------------------------------------------------
/tests/test_chat_template.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 | from transformers import AutoTokenizer
5 |
6 | from mergekit.config import InputModelDefinition, MergeConfiguration
7 | from tests.common import make_picollama, run_and_check_merge
8 | from tests.test_tokenizer import make_tokenizer
9 |
10 |
11 | @pytest.fixture(scope="session")
12 | def model_base(tmp_path_factory):
13 |     model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
14 |     make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
15 |     return model_path
16 |
17 |
18 | @pytest.fixture(scope="session")
19 | def model_b(tmp_path_factory):
20 |     return make_picollama(tmp_path_factory.mktemp("model_b"))
21 |
22 |
23 | def check_chat_template(model_path: str, needle: Optional[str] = None):
24 |     tokenizer = AutoTokenizer.from_pretrained(model_path)
25 |     if needle is None:
26 |         assert not tokenizer.chat_template, "Expected no chat template"
27 |         return
28 |     assert (
29 |         tokenizer.chat_template and needle in tokenizer.chat_template
30 |     ), f"Expected chat template to contain {needle}"
31 |
32 |
33 | class TestChatTemplate:
34 |     def test_template_chatml(self, model_base, model_b):
35 |         config = MergeConfiguration(
36 |             merge_method="linear",
37 |             models=[
38 |                 InputModelDefinition(model=model_base, parameters={"weight": 0.5}),
39 |                 InputModelDefinition(model=model_b, parameters={"weight": 0.5}),
40 |             ],
41 |             base_model=model_base,
42 |             dtype="bfloat16",
43 |             chat_template="chatml",
44 |         )
45 |         run_and_check_merge(
46 |             config,
47 |             validate=lambda p: check_chat_template(p, "<|im_start|>"),
48 |         )
49 |
50 |     def test_template_literal_jinja(self, model_base, model_b):
51 |         config = MergeConfiguration(
52 |             merge_method="linear",
53 |             models=[
54 |                 InputModelDefinition(model=model_base, parameters={"weight": 0.5}),
55 |                 InputModelDefinition(model=model_b, parameters={"weight": 0.5}),
56 |             ],
57 |             base_model=model_base,
58 |             dtype="bfloat16",
59 |             chat_template="{{messages[0]['content']}}",
60 |         )
61 |         run_and_check_merge(
62 |             config,
63 |             validate=lambda p: check_chat_template(p, "{{messages[0]['content']}}"),
64 |         )
65 |
--------------------------------------------------------------------------------
/tests/test_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import torch
5 |
6 | from mergekit.io import TensorWriter
7 |
8 |
9 | class TestTensorWriter:
10 |     def test_safetensors(self):
11 |         with tempfile.TemporaryDirectory() as d:
12 |             writer = TensorWriter(d, safe_serialization=True)
13 |             writer.save_tensor("steve", torch.randn(4))
14 |             writer.finalize()
15 |
16 |             assert os.path.exists(os.path.join(d, "model.safetensors"))
17 |
18 |     def test_pickle(self):
19 |         with tempfile.TemporaryDirectory() as d:
20 |             writer = TensorWriter(d, safe_serialization=False)
21 |             writer.save_tensor("timothan", torch.randn(4))
22 |             writer.finalize()
23 |
24 |             assert os.path.exists(os.path.join(d, "pytorch_model.bin"))
25 |
26 |     def test_duplicate_tensor(self):
27 |         with tempfile.TemporaryDirectory() as d:
28 |             writer = TensorWriter(d, safe_serialization=True)
29 |             jim = torch.randn(4)
30 |             writer.save_tensor("jim", jim)
31 |             writer.save_tensor("jimbo", jim)
32 |             writer.finalize()
33 |
34 |             assert os.path.exists(os.path.join(d, "model.safetensors"))
35 |
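The writer pairs naturally with the loader exercised in the next file; a round-trip sketch using only APIs shown in this dump, assuming LazyTensorLoader.from_disk handles a single-shard safetensors directory as it does in run_and_check_merge above:

    import tempfile

    import torch

    from mergekit.io import LazyTensorLoader, TensorWriter

    with tempfile.TemporaryDirectory() as d:
        writer = TensorWriter(d, safe_serialization=True)
        original = torch.randn(4)
        writer.save_tensor("example", original)
        writer.finalize()

        # read the shard back lazily and compare
        loader = LazyTensorLoader.from_disk(d, lazy_unpickle=False)
        assert torch.equal(loader.get_tensor("example"), original)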
--------------------------------------------------------------------------------
/tests/test_lazy_unpickle.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mergekit.io import LazyTensorLoader
4 |
5 |
6 | class TestLazyUnpickle:
7 |     def test_lazy_unpickle(self, tmp_path):
8 |         data = {
9 |             "a": torch.tensor([1, 2, 3]),
10 |             "b": torch.tensor([4, 5, 6]),
11 |         }
12 |         path = tmp_path / "pytorch_model.bin"
13 |         torch.save(data, path)
14 |         loader = LazyTensorLoader.from_disk(tmp_path)
15 |         for name in data:
16 |             assert name in loader.index.tensor_paths
17 |             tensor = loader.get_tensor(name)
18 |             assert torch.equal(tensor, data[name])
19 |
--------------------------------------------------------------------------------
/tests/test_modelref.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from mergekit.common import ModelPath, ModelReference
4 |
5 |
6 | class TestModelReference:
7 |     def test_parse_simple(self):
8 |         text = "hf_user/model"
9 |         mr = ModelReference.parse(text)
10 |         assert mr.model == ModelPath(path="hf_user/model", revision=None)
11 |         assert mr.lora is None
12 |         assert str(mr) == text
13 |
14 |     def test_parse_lora(self):
15 |         text = "hf_user/model+hf_user/lora"
16 |         mr = ModelReference.parse(text)
17 |         assert mr.model == ModelPath(path="hf_user/model", revision=None)
18 |         assert mr.lora == ModelPath(path="hf_user/lora", revision=None)
19 |         assert str(mr) == text
20 |
21 |     def test_parse_revision(self):
22 |         text = "hf_user/model@v0.0.1"
23 |         mr = ModelReference.parse(text)
24 |         assert mr.model == ModelPath(path="hf_user/model", revision="v0.0.1")
25 |         assert mr.lora is None
26 |         assert str(mr) == text
27 |
28 |     def test_parse_lora_plus_revision(self):
29 |         text = "hf_user/model@v0.0.1+hf_user/lora@main"
30 |         mr = ModelReference.parse(text)
31 |         assert mr.model == ModelPath(path="hf_user/model", revision="v0.0.1")
32 |         assert mr.lora == ModelPath(path="hf_user/lora", revision="main")
33 |         assert str(mr) == text
34 |
35 |     def test_parse_bad(self):
36 |         with pytest.raises(RuntimeError):
37 |             ModelReference.parse("@@@@@")
38 |
39 |         with pytest.raises(RuntimeError):
40 |             ModelReference.parse("a+b+c")
41 |
42 |         with pytest.raises(RuntimeError):
43 |             ModelReference.parse("a+b+c@d+e@f@g")
44 |
--------------------------------------------------------------------------------
/tests/test_sparsify.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from mergekit.sparsify import SparsificationMethod, sparsify
5 |
6 |
7 | @pytest.fixture
8 | def sample_tensor():
9 |     res = torch.randn(128, 64)
10 |     res[res == 0] = 7  # very low chance, but hey!
11 |     return res
12 |
13 |
14 | class TestMagnitude:
15 |     def test_full_density(self, sample_tensor):
16 |         assert torch.equal(
17 |             sparsify(sample_tensor, density=1, method=SparsificationMethod.magnitude),
18 |             sample_tensor,
19 |         )
20 |
21 |     def test_zero_density(self, sample_tensor):
22 |         with pytest.raises(AssertionError):
23 |             sparsify(sample_tensor, density=0, method=SparsificationMethod.magnitude)
24 |
25 |     def test_partial_density(self, sample_tensor):
26 |         result = sparsify(
27 |             sample_tensor, density=0.5, method=SparsificationMethod.magnitude
28 |         )
29 |         assert torch.count_nonzero(result) == sample_tensor.view(-1).shape[0] // 2
30 |
31 |     def test_outliers(self, sample_tensor):
32 |         for gamma_0 in [0.1, 0.2, 0.5, 1.0]:
33 |             for density in [0.1, 0.3, 0.5, 0.6, 0.9, 1.0]:
34 |                 sparsity = 1 - density
35 |                 gamma = gamma_0 * sparsity
36 |                 result = sparsify(
37 |                     sample_tensor,
38 |                     density=density,
39 |                     method=SparsificationMethod.magnitude_outliers,
40 |                     gamma=gamma,
41 |                 )
42 |                 assert torch.count_nonzero(result) == int(
43 |                     sample_tensor.view(-1).shape[0] * density
44 |                 )
45 |
46 |     def test_norm_rescale(self, sample_tensor):
47 |         l1_norm = sample_tensor.abs().sum()
48 |         l2_norm = sample_tensor.norm()
49 |         linf_norm = sample_tensor.abs().max()
50 |
51 |         normed_l1 = sparsify(
52 |             sample_tensor,
53 |             density=0.5,
54 |             method=SparsificationMethod.magnitude,
55 |             rescale_norm="l1",
56 |         )
57 |         normed_l2 = sparsify(
58 |             sample_tensor,
59 |             density=0.5,
60 |             method=SparsificationMethod.magnitude,
61 |             rescale_norm="l2",
62 |         )
63 |         normed_linf = sparsify(
64 |             sample_tensor,
65 |             density=0.5,
66 |             method=SparsificationMethod.magnitude,
67 |             rescale_norm="linf",
68 |         )
69 |
70 |         assert torch.isclose(normed_l1.abs().sum(), l1_norm, rtol=0.01)
71 |         assert torch.isclose(normed_l2.norm(), l2_norm, rtol=0.01)
72 |         assert torch.isclose(normed_linf.abs().max(), linf_norm, rtol=0.01)
73 |
74 |     def test_della_magprune(self, sample_tensor):
75 |         res = sparsify(
76 |             sample_tensor,
77 |             density=0.5,
78 |             method=SparsificationMethod.della_magprune,
79 |             epsilon=0.05,
80 |             rescale_norm="l1",
81 |         )
82 |         assert not res.isnan().any(), "NaNs in result tensor"
83 |         assert not res.isinf().any(), "Infs in result tensor"
84 |
85 |
86 | class TestBernoulli:
87 |     NUM_ITERATIONS = 1000
88 |
89 |     def test_bernoulli_with_rescale(self, sample_tensor):
90 |         ref_abs_sum = sample_tensor.abs().sum()
91 |         avg_abs_sum = torch.zeros_like(ref_abs_sum)
92 |         for _ in range(TestBernoulli.NUM_ITERATIONS):
93 |             rescaled = sparsify(
94 |                 sample_tensor,
95 |                 density=0.5,
96 |                 method=SparsificationMethod.random,
97 |                 rescale_norm="l1",
98 |             )
99 |             avg_abs_sum += rescaled.abs().sum()
100 |         avg_abs_sum /= TestBernoulli.NUM_ITERATIONS
101 |
102 |         assert torch.isclose(avg_abs_sum, ref_abs_sum, rtol=0.01)
103 |
104 |     def test_bernoulli_without_rescale(self, sample_tensor):
105 |         result = sparsify(
106 |             sample_tensor,
107 |             density=0.5,
108 |             method=SparsificationMethod.random,
109 |             rescale_norm=None,
110 |         )
111 |         assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
112 |
113 |     def test_cpu_dtypes(self, sample_tensor):
114 |         for dt in (torch.float16, torch.bfloat16, torch.float32):
115 |             sparsify(
116 |                 tensor=sample_tensor.to(dtype=dt).cpu(),
117 |                 density=0.5,
118 |                 method=SparsificationMethod.random,
119 |                 rescale_norm="l1",
120 |             )
121 |
--------------------------------------------------------------------------------
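For reference, a simplified sketch of the magnitude method exercised above. This is an assumption-laden illustration, not the library code: the actual implementation in mergekit/sparsify.py also handles rescaling and the other SparsificationMethod variants.

    import torch


    def magnitude_sparsify(tensor: torch.Tensor, density: float) -> torch.Tensor:
        """Keep the top `density` fraction of entries by absolute value, zero the rest."""
        assert 0 < density <= 1, "density must be in (0, 1]"
        if density == 1:
            return tensor
        flat = tensor.abs().view(-1)
        k = int(flat.numel() * density)
        mask = torch.zeros_like(flat)
        # mark the indices of the k largest-magnitude entries
        mask[torch.topk(flat, k=k).indices] = 1
        return tensor * mask.view_as(tensor)

Selecting exactly k = numel * density indices matches the exact-count assertions in TestMagnitude, and the fixture's res[res == 0] = 7 guards those nonzero counts against entries that happen to be zero before pruning.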