├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── examples.py ├── hypercloning ├── __init__.py ├── common.py ├── gemma_cloning.py ├── llama_cloning.py ├── olmo_cloning.py ├── opt_cloning.py └── pythia_cloning.py ├── images └── teaser.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ### macOS ### 2 | # General 3 | .DS_Store 4 | .AppleDouble 5 | .LSOverride 6 | 7 | # Icon must end with two \r 8 | Icon 9 | 10 | 11 | # Thumbnails 12 | ._* 13 | 14 | # Files that might appear in the root of a volume 15 | .DocumentRevisions-V100 16 | .fseventsd 17 | .Spotlight-V100 18 | .TemporaryItems 19 | .Trashes 20 | .VolumeIcon.icns 21 | .com.apple.timemachine.donotpresent 22 | 23 | # Directories potentially created on remote AFP share 24 | .AppleDB 25 | .AppleDesktop 26 | Network Trash Folder 27 | Temporary Items 28 | .apdisk 29 | 30 | ### macOS Patch ### 31 | # iCloud generated files 32 | *.icloud 33 | 34 | ### Python ### 35 | # Byte-compiled / optimized / DLL files 36 | __pycache__/ 37 | *.py[cod] 38 | *$py.class 39 | 40 | # C extensions 41 | *.so 42 | 43 | # Distribution / packaging 44 | .Python 45 | build/ 46 | develop-eggs/ 47 | dist/ 48 | downloads/ 49 | eggs/ 50 | .eggs/ 51 | lib/ 52 | lib64/ 53 | parts/ 54 | sdist/ 55 | var/ 56 | wheels/ 57 | share/python-wheels/ 58 | *.egg-info/ 59 | .installed.cfg 60 | *.egg 61 | MANIFEST 62 | 63 | # PyInstaller 64 | # Usually these files are written by a python script from a template 65 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 66 | *.manifest 67 | *.spec 68 | 69 | # Installer logs 70 | pip-log.txt 71 | pip-delete-this-directory.txt 72 | 73 | # Unit test / coverage reports 74 | htmlcov/ 75 | .tox/ 76 | .nox/ 77 | .coverage 78 | .coverage.* 79 | .cache 80 | nosetests.xml 81 | coverage.xml 82 | *.cover 83 | *.py,cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | cover/ 87 | 88 | # Translations 89 | *.mo 90 | *.pot 91 | 92 | # Django stuff: 93 | *.log 94 | local_settings.py 95 | db.sqlite3 96 | db.sqlite3-journal 97 | 98 | # Flask stuff: 99 | instance/ 100 | .webassets-cache 101 | 102 | # Scrapy stuff: 103 | .scrapy 104 | 105 | # Sphinx documentation 106 | docs/_build/ 107 | 108 | # PyBuilder 109 | .pybuilder/ 110 | target/ 111 | 112 | # Jupyter Notebook 113 | .ipynb_checkpoints 114 | 115 | # IPython 116 | profile_default/ 117 | ipython_config.py 118 | 119 | # pyenv 120 | # For a library or package, you might want to ignore these files since the code is 121 | # intended to run in multiple environments; otherwise, check them in: 122 | # .python-version 123 | 124 | # pipenv 125 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 126 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 127 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 128 | # install all needed dependencies. 129 | #Pipfile.lock 130 | 131 | # poetry 132 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 133 | # This is especially recommended for binary packages to ensure reproducibility, and is more 134 | # commonly ignored for libraries. 
135 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 136 | #poetry.lock 137 | 138 | # pdm 139 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 140 | #pdm.lock 141 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 142 | # in version control. 143 | # https://pdm.fming.dev/#use-with-ide 144 | .pdm.toml 145 | 146 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 147 | __pypackages__/ 148 | 149 | # Celery stuff 150 | celerybeat-schedule 151 | celerybeat.pid 152 | 153 | # SageMath parsed files 154 | *.sage.py 155 | 156 | # Environments 157 | .env 158 | .venv 159 | env/ 160 | venv/ 161 | ENV/ 162 | env.bak/ 163 | venv.bak/ 164 | 165 | # Spyder project settings 166 | .spyderproject 167 | .spyproject 168 | 169 | # Rope project settings 170 | .ropeproject 171 | 172 | # mkdocs documentation 173 | /site 174 | 175 | # mypy 176 | .mypy_cache/ 177 | .dmypy.json 178 | dmypy.json 179 | 180 | # Pyre type checker 181 | .pyre/ 182 | 183 | # pytype static type analyzer 184 | .pytype/ 185 | 186 | # Cython debug symbols 187 | cython_debug/ 188 | 189 | 190 | ### Python Patch ### 191 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 192 | poetry.toml 193 | 194 | # ruff 195 | .ruff_cache/ 196 | 197 | # LSP config files 198 | pyrightconfig.json 199 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this software may utilize the Hugging Face transformers library to import pre-trained models. The use of this library is hereby acknowledged. The license of this library (as of October 10 2024) is copied below. 3 | 4 | _____________________ 5 | 6 | Copyright 2018- The Hugging Face team. All rights reserved. 7 | 8 | Apache License 9 | Version 2.0, January 2004 10 | http://www.apache.org/licenses/ 11 | 12 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 13 | 14 | 1. Definitions. 15 | 16 | "License" shall mean the terms and conditions for use, reproduction, 17 | and distribution as defined by Sections 1 through 9 of this document. 18 | 19 | "Licensor" shall mean the copyright owner or entity authorized by 20 | the copyright owner that is granting the License. 21 | 22 | "Legal Entity" shall mean the union of the acting entity and all 23 | other entities that control, are controlled by, or are under common 24 | control with that entity. For the purposes of this definition, 25 | "control" means (i) the power, direct or indirect, to cause the 26 | direction or management of such entity, whether by contract or 27 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 28 | outstanding shares, or (iii) beneficial ownership of such entity. 29 | 30 | "You" (or "Your") shall mean an individual or Legal Entity 31 | exercising permissions granted by this License. 32 | 33 | "Source" form shall mean the preferred form for making modifications, 34 | including but not limited to software source code, documentation 35 | source, and configuration files. 36 | 37 | "Object" form shall mean any form resulting from mechanical 38 | transformation or translation of a Source form, including but 39 | not limited to compiled object code, generated documentation, 40 | and conversions to other media types.
41 | 42 | "Work" shall mean the work of authorship, whether in Source or 43 | Object form, made available under the License, as indicated by a 44 | copyright notice that is included in or attached to the work 45 | (an example is provided in the Appendix below). 46 | 47 | "Derivative Works" shall mean any work, whether in Source or Object 48 | form, that is based on (or derived from) the Work and for which the 49 | editorial revisions, annotations, elaborations, or other modifications 50 | represent, as a whole, an original work of authorship. For the purposes 51 | of this License, Derivative Works shall not include works that remain 52 | separable from, or merely link (or bind by name) to the interfaces of, 53 | the Work and Derivative Works thereof. 54 | 55 | "Contribution" shall mean any work of authorship, including 56 | the original version of the Work and any modifications or additions 57 | to that Work or Derivative Works thereof, that is intentionally 58 | submitted to Licensor for inclusion in the Work by the copyright owner 59 | or by an individual or Legal Entity authorized to submit on behalf of 60 | the copyright owner. For the purposes of this definition, "submitted" 61 | means any form of electronic, verbal, or written communication sent 62 | to the Licensor or its representatives, including but not limited to 63 | communication on electronic mailing lists, source code control systems, 64 | and issue tracking systems that are managed by, or on behalf of, the 65 | Licensor for the purpose of discussing and improving the Work, but 66 | excluding communication that is conspicuously marked or otherwise 67 | designated in writing by the copyright owner as "Not a Contribution." 68 | 69 | "Contributor" shall mean Licensor and any individual or Legal Entity 70 | on behalf of whom a Contribution has been received by Licensor and 71 | subsequently incorporated within the Work. 72 | 73 | 2. Grant of Copyright License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | copyright license to reproduce, prepare Derivative Works of, 77 | publicly display, publicly perform, sublicense, and distribute the 78 | Work and such Derivative Works in Source or Object form. 79 | 80 | 3. Grant of Patent License. Subject to the terms and conditions of 81 | this License, each Contributor hereby grants to You a perpetual, 82 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 83 | (except as stated in this section) patent license to make, have made, 84 | use, offer to sell, sell, import, and otherwise transfer the Work, 85 | where such license applies only to those patent claims licensable 86 | by such Contributor that are necessarily infringed by their 87 | Contribution(s) alone or by combination of their Contribution(s) 88 | with the Work to which such Contribution(s) was submitted. If You 89 | institute patent litigation against any entity (including a 90 | cross-claim or counterclaim in a lawsuit) alleging that the Work 91 | or a Contribution incorporated within the Work constitutes direct 92 | or contributory patent infringement, then any patent licenses 93 | granted to You under this License for that Work shall terminate 94 | as of the date such litigation is filed. 95 | 96 | 4. Redistribution. 
You may reproduce and distribute copies of the 97 | Work or Derivative Works thereof in any medium, with or without 98 | modifications, and in Source or Object form, provided that You 99 | meet the following conditions: 100 | 101 | (a) You must give any other recipients of the Work or 102 | Derivative Works a copy of this License; and 103 | 104 | (b) You must cause any modified files to carry prominent notices 105 | stating that You changed the files; and 106 | 107 | (c) You must retain, in the Source form of any Derivative Works 108 | that You distribute, all copyright, patent, trademark, and 109 | attribution notices from the Source form of the Work, 110 | excluding those notices that do not pertain to any part of 111 | the Derivative Works; and 112 | 113 | (d) If the Work includes a "NOTICE" text file as part of its 114 | distribution, then any Derivative Works that You distribute must 115 | include a readable copy of the attribution notices contained 116 | within such NOTICE file, excluding those notices that do not 117 | pertain to any part of the Derivative Works, in at least one 118 | of the following places: within a NOTICE text file distributed 119 | as part of the Derivative Works; within the Source form or 120 | documentation, if provided along with the Derivative Works; or, 121 | within a display generated by the Derivative Works, if and 122 | wherever such third-party notices normally appear. The contents 123 | of the NOTICE file are for informational purposes only and 124 | do not modify the License. You may add Your own attribution 125 | notices within Derivative Works that You distribute, alongside 126 | or as an addendum to the NOTICE text from the Work, provided 127 | that such additional attribution notices cannot be construed 128 | as modifying the License. 129 | 130 | You may add Your own copyright statement to Your modifications and 131 | may provide additional or different license terms and conditions 132 | for use, reproduction, or distribution of Your modifications, or 133 | for any such Derivative Works as a whole, provided Your use, 134 | reproduction, and distribution of the Work otherwise complies with 135 | the conditions stated in this License. 136 | 137 | 5. Submission of Contributions. Unless You explicitly state otherwise, 138 | any Contribution intentionally submitted for inclusion in the Work 139 | by You to the Licensor shall be under the terms and conditions of 140 | this License, without any additional terms or conditions. 141 | Notwithstanding the above, nothing herein shall supersede or modify 142 | the terms of any separate license agreement you may have executed 143 | with Licensor regarding such Contributions. 144 | 145 | 6. Trademarks. This License does not grant permission to use the trade 146 | names, trademarks, service marks, or product names of the Licensor, 147 | except as required for reasonable and customary use in describing the 148 | origin of the Work and reproducing the content of the NOTICE file. 149 | 150 | 7. Disclaimer of Warranty. Unless required by applicable law or 151 | agreed to in writing, Licensor provides the Work (and each 152 | Contributor provides its Contributions) on an "AS IS" BASIS, 153 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 154 | implied, including, without limitation, any warranties or conditions 155 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 156 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 157 | appropriateness of using or redistributing the Work and assume any 158 | risks associated with Your exercise of permissions under this License. 159 | 160 | 8. Limitation of Liability. In no event and under no legal theory, 161 | whether in tort (including negligence), contract, or otherwise, 162 | unless required by applicable law (such as deliberate and grossly 163 | negligent acts) or agreed to in writing, shall any Contributor be 164 | liable to You for damages, including any direct, indirect, special, 165 | incidental, or consequential damages of any character arising as a 166 | result of this License or out of the use or inability to use the 167 | Work (including but not limited to damages for loss of goodwill, 168 | work stoppage, computer failure or malfunction, or any and all 169 | other commercial damages or losses), even if such Contributor 170 | has been advised of the possibility of such damages. 171 | 172 | 9. Accepting Warranty or Additional Liability. While redistributing 173 | the Work or Derivative Works thereof, You may choose to offer, 174 | and charge a fee for, acceptance of support, warranty, indemnity, 175 | or other liability obligations and/or rights consistent with this 176 | License. However, in accepting such obligations, You may act only 177 | on Your own behalf and on Your sole responsibility, not on behalf 178 | of any other Contributor, and only if You agree to indemnify, 179 | defend, and hold each Contributor harmless for any liability 180 | incurred by, or claims asserted against, such Contributor by reason 181 | of your accepting any such warranty or additional liability. 182 | 183 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE.txt). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 
40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH ML-HYPERCLONING: 43 | 44 | The ml-hypercloning software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperCloning 2 | 3 | This software project accompanies the research paper, [Scaling Smart: Accelerating Large Language Model Pre-training with Small Model Initialization](https://arxiv.org/abs/2409.12903). 4 | 5 | HyperCloning transfers the knowledge of a small pre-trained LLM to a larger LLM by using the smaller model's weights to initialize the larger one. The larger LLM can then be trained further to improve its accuracy. 6 | 7 | ![Illustration of HyperCloning for Linear Layers](images/teaser.png) 8 | 9 | ## Installation 10 | 11 | `pip install -r requirements.txt` 12 | 13 | ## Sample Code 14 | 15 | The following snippet shows how to clone a source model into a destination model: 16 | 17 | ``` 18 | from transformers import AutoModelForCausalLM 19 | from hypercloning import cloneModel 20 | 21 | # instantiate the source model (pretrained): 22 | source_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") 23 | 24 | # Clone a model with 2x embedding size and 2x FFN dimension: 25 | destination_model = cloneModel(source_model, embedding_dim_multiplier=2, up_project_multiplier=2) 26 | ``` 27 | 28 | You may modify and run the following to perform cloning on supported models: 29 | 30 | ``` 31 | python examples.py 32 | ``` 33 | ## Supported Models 34 | The following families of models are currently supported: 35 | - [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) 36 | - [Pythia](https://huggingface.co/models?other=pythia) 37 | - [OLMo](https://huggingface.co/docs/transformers/en/model_doc/olmo) 38 | - [Gemma](https://huggingface.co/docs/transformers/en/model_doc/gemma) 39 | - [Llama](https://huggingface.co/docs/transformers/en/model_doc/llama2) 40 | 41 | ## Limitations 42 | - The current implementation requires `embedding_dim_multiplier` and `up_project_multiplier` to be integers; fractional values are not supported. 43 | - Although the destination network's output is valid, it may not match the source network's output exactly, owing to numerical precision issues. 44 | - For attention layers, we recommend expanding only the number of attention heads while keeping the head_size of each head unchanged; supporting head_size changes would make the code considerably more complicated. 45 | 46 | ## References 47 | 48 | For citations, you may use the following: 49 | ``` 50 | @article{samragh2024scaling, 51 | title={Scaling Smart: Accelerating Large Language Model Pre-training with Small Model Initialization}, 52 | author={Samragh, Mohammad and Mirzadeh, Iman and Vahid, Keivan Alizadeh and Faghri, Fartash and Cho, Minsik and Nabi, Moin and Naik, Devang and Farajtabar, Mehrdad}, 53 | journal={arXiv preprint arXiv:2409.12903}, 54 | year={2024} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4 | # 5 | 6 | import os 7 | 8 | import torch 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | from hypercloning import cloneModel 12 | 13 | hf_token = os.environ["HF_TOKEN"] 14 | example_models = [ 15 | "meta-llama/Llama-3.2-1B", 16 | "google/gemma-2-2b-it", 17 | "meta-llama/Llama-2-7b-hf", 18 | "allenai/OLMO-1b", 19 | "google/gemma-2b", 20 | "EleutherAI/pythia-410m-deduped", 21 | "facebook/opt-350m", 22 | ] 23 | 24 | 25 | def main(): 26 | query = "The capital of France is" 27 | device = "cpu" 28 | for hf_path in example_models: 29 | print(f"\n\n\n##########################{hf_path}##########################") 30 | src_network = AutoModelForCausalLM.from_pretrained( 31 | hf_path, trust_remote_code=True, torch_dtype=torch.float32, token=hf_token 32 | ) 33 | print(f"******** source network:\n{src_network}\n") 34 | dst_network = cloneModel( 35 | src_network, embedding_dim_multiplier=2, up_project_multiplier=2 36 | ) 37 | print(f"******** target network:\n{dst_network}\n") 38 | src_network.to(device) 39 | dst_network.to(device) 40 | tokenizer = AutoTokenizer.from_pretrained(hf_path, token=hf_token) 41 | inputs = tokenizer(query, return_tensors="pt") 42 | outputs = src_network.generate( 43 | inputs["input_ids"].to(device), 44 | max_length=50, 45 | num_return_sequences=1, 46 | do_sample=False, 47 | ) 48 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 49 | outputs = dst_network.generate( 50 | inputs["input_ids"].to(device), 51 | max_length=50, 52 | num_return_sequences=1, 53 | do_sample=False, 54 | ) 55 | generated_text2 = tokenizer.decode(outputs[0], skip_special_tokens=True) 56 | print( 57 | f"******** generated by source network: \n {generated_text}\n ******** \ 58 | generated by target network: \n {generated_text2}" 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /hypercloning/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from hf_olmo.configuration_olmo import OLMoConfig 7 | from transformers import (Gemma2Config, GemmaConfig, GPTNeoXConfig, 8 | LlamaConfig, OPTConfig) 9 | 10 | from hypercloning.gemma_cloning import clone_gemma, clone_gemma2 11 | from hypercloning.llama_cloning import clone_llama 12 | from hypercloning.olmo_cloning import clone_olmo 13 | from hypercloning.opt_cloning import clone_opt 14 | from hypercloning.pythia_cloning import clone_pythia 15 | 16 | REGISTERED_CLONING_FUNCTIONS = { 17 | "LlamaConfig": clone_llama, 18 | "GemmaConfig": clone_gemma, 19 | "Gemma2Config": clone_gemma2, 20 | "OPTConfig": clone_opt, 21 | "OLMoConfig": clone_olmo, 22 | "GPTNeoXConfig": clone_pythia, 23 | } 24 | 25 | 26 | def cloneModel( 27 | model, embedding_dim_multiplier: int, up_project_multiplier: int, **kwargs 28 | ): 29 | """ 30 | Expand 'model' according to 'embedding_dim_multiplier' and 31 | 'up_project_multiplier'. 32 | 33 | Arguments: 34 | embedding_dim_multiplier: 35 | Expansion factor for embedding size. 36 | up_project_multiplier: 37 | Expansion factor for the FFN layers. 38 | kwargs can include: 39 | snr_db: 40 | Signal to noise ratio in decibels if noise is desired to be 41 | added to the weight tensors. Defaults to None. 
42 | num_heads_multiplier: 43 | The ratio of the number of heads in the destination network 44 | divided by the number of heads in the source network. 45 | Defaults to 'embedding_dim_multiplier' (recommended). 46 | 47 | Returns: 48 | Cloned model with expanded parameters. 49 | """ 50 | cloning_function_key = str(type(model.config)).split(".")[-1][:-2].strip() 51 | 52 | assert ( 53 | cloning_function_key in REGISTERED_CLONING_FUNCTIONS 54 | ), f"cloning is not supported for model config of type {cloning_function_key}" 55 | cloning_function = REGISTERED_CLONING_FUNCTIONS[cloning_function_key] 56 | print(f"cloning the network using {cloning_function} ...") 57 | return cloning_function( 58 | model, 59 | embedding_dim_multiplier=embedding_dim_multiplier, 60 | up_project_multiplier=up_project_multiplier, 61 | **kwargs, 62 | ) 63 | -------------------------------------------------------------------------------- /hypercloning/common.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | 8 | import torch 9 | 10 | 11 | def scale_linear_layer(layer: torch.nn.Linear, scaler: float): 12 | """ 13 | Scales the parameters of 'layer' so that its output is multiplied by 'scaler'. 14 | 15 | Arguments: 16 | layer: 17 | Linear layer to be scaled. 18 | scaler: 19 | Value to multiply the layer output. 20 | 21 | Returns: 22 | None. 23 | """ 24 | layer.weight.data *= scaler 25 | if layer.bias is not None: 26 | layer.bias.data *= scaler 27 | 28 | 29 | def get_noise_with_snr(weight: torch.Tensor, snr_db: float): 30 | """ 31 | Generates Gaussian noise to be added to 'weight' so that the signal-to-noise 32 | ratio becomes 'snr_db'. 33 | 34 | Arguments: 35 | weight: 36 | Signal tensor. 37 | snr_db: 38 | Signal-to-noise ratio in decibels. 39 | 40 | Returns: 41 | Noise tensor. 42 | """ 43 | signal_power = torch.mean(weight**2) 44 | snr_linear = 10 ** (snr_db / 10) 45 | noise_power = signal_power / snr_linear 46 | noise = torch.randn_like(weight) 47 | current_noise_power = torch.mean(noise**2) 48 | noise = noise * torch.sqrt(noise_power / current_noise_power) 49 | return noise.to(weight.dtype) 50 | 51 | 52 | def add_noise(weight, block_shape, snr_db): 53 | """ 54 | Repeatedly adds and subtracts noise to 'block_shape' blocks within 'weight'. 55 | 56 | The noise is applied in alternating blocks of 'block_shape'. 57 | Below are several illustrations: 58 | 59 | Examples 1 & 2, even repetition of columns: 60 | +-------+-------+ +-------+-------+ 61 | | W | W | | W+N1 | W-N1 | 62 | +-------+-------+ --> +-------+-------+ 63 | | W | W | | W+N2 | W-N2 | 64 | +-------+-------+ +-------+-------+ 65 | 66 | +-------+-------+-------+-------+ +-------+-------+-------+-------+ 67 | | W | W | W | W | | W+N1 | W-N1 | W+N2 | W-N2 | 68 | +-------+-------+-------+-------+ --> +-------+-------+-------+-------+ 69 | | W | W | W | W | | W+N3 | W-N3 | W+N4 | W-N4 | 70 | +-------+-------+-------+-------+ +-------+-------+-------+-------+ 71 | 72 | Example 3, odd repetition of columns: 73 | +-------+-------+-------+ +-------+-------+-------+ 74 | | W | W | W | | W+N1 | W-N1 | W | 75 | +-------+-------+-------+ --> +-------+-------+-------+ 76 | | W | W | W | | W+N2 | W-N2 | W | 77 | +-------+-------+-------+ +-------+-------+-------+ 78 | 79 | Arguments: 80 | weight: 81 | Signal tensor. 82 | block_shape: 83 | Shape of the block to which noise is added or subtracted.
84 | snr_db: 85 | Signal-to-noise ratio in decibels. 86 | 87 | Returns: 88 | Noisy weight. 89 | """ 90 | assert weight.shape[0] % block_shape[0] == 0 91 | assert weight.shape[1] % block_shape[1] == 0 92 | n_repeat_0 = weight.shape[0] // block_shape[0] 93 | n_repeat_1 = weight.shape[1] // block_shape[1] 94 | if weight.ndim == 2: 95 | for n0 in range(n_repeat_0): 96 | start0 = n0 * block_shape[0] 97 | end0 = start0 + block_shape[0] 98 | for n1 in range(n_repeat_1 // 2): 99 | start1 = 2 * n1 * block_shape[1] 100 | end1 = start1 + block_shape[1] 101 | start2 = (2 * n1 + 1) * block_shape[1] 102 | end2 = start2 + block_shape[1] 103 | noise = get_noise_with_snr(weight[start0:end0, start1:end1], snr_db) 104 | weight[start0:end0, start1:end1] += noise 105 | weight[start0:end0, start2:end2] -= noise 106 | return weight 107 | else: 108 | for n0 in range(weight.shape[0]): 109 | weight[n0] = add_noise(weight[n0], block_shape[1:], snr_db) 110 | return weight 111 | 112 | 113 | def clone_matrix(dst_weight_shape, src_weight, snr_db=None, normalize=True): 114 | """ 115 | Clones a matrix from 'src_weight' into 'dst_weight_shape'. 116 | 117 | Arguments: 118 | dst_weight_shape: 119 | Shape of the destination matrix. Each dimension must be an 120 | integer multiple of the corresponding dimension of src_weight.shape. 121 | src_weight: 122 | Source weight to be cloned. 123 | snr_db: 124 | Signal-to-noise ratio in case noise is to be added. 125 | Defaults to None (no noise added). 126 | normalize: 127 | If True, normalize the weight by the number of repetitions 128 | in the second dimension. 129 | 130 | Returns: 131 | Cloned matrix with shape 'dst_weight_shape'. 132 | """ 133 | out_features_old, in_features_old = src_weight.shape 134 | out_features_new, in_features_new = dst_weight_shape 135 | assert out_features_new >= out_features_old 136 | assert out_features_new % out_features_old == 0 137 | assert in_features_new >= in_features_old 138 | assert ( 139 | in_features_new % in_features_old == 0 140 | ), f"{in_features_old} does not divide {in_features_new}" 141 | n_repeat_0 = out_features_new // out_features_old 142 | n_repeat_1 = in_features_new // in_features_old 143 | 144 | dst_weight = src_weight.data.repeat(n_repeat_0, n_repeat_1) 145 | if normalize: 146 | dst_weight = dst_weight / n_repeat_1 147 | if snr_db is not None: 148 | dst_weight = add_noise(dst_weight, src_weight.shape, snr_db) 149 | return dst_weight 150 | 151 | 152 | def clone_vector(dst_vector_shape, src_vector): 153 | """ 154 | Clones a vector from 'src_vector' into 'dst_vector_shape'. 155 | 156 | Arguments: 157 | dst_vector_shape: 158 | Shape of the destination vector. Must be an integer multiple of src_vector.shape. 159 | src_vector: 160 | Source vector to be cloned. 161 | 162 | Returns: 163 | Cloned vector with shape 'dst_vector_shape'. 164 | """ 165 | assert src_vector.shape[0] <= dst_vector_shape[0] 166 | assert dst_vector_shape[0] % src_vector.shape[0] == 0 167 | n_repeat = dst_vector_shape[0] // src_vector.shape[0] 168 | dst_vector = src_vector.repeat(n_repeat) 169 | return dst_vector 170 | 171 | 172 | def clone_linear_layer(dst_layer, src_layer, snr_db=None): 173 | """ 174 | Clones linear layer parameters from 'src_layer' into 'dst_layer'. 175 | 176 | Arguments: 177 | dst_layer: 178 | Destination linear layer. 179 | src_layer: 180 | Source pretrained linear layer. 181 | snr_db: 182 | Optional signal-to-noise ratio (in decibels) of the noise added to the weight parameters of the destination layer. 183 | 184 | Returns: 185 | None.
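    Example (illustrative sketch; the layer sizes below are hypothetical):
        src = torch.nn.Linear(8, 8)
        dst = torch.nn.Linear(16, 16)
        clone_linear_layer(dst, src)
        # dst.weight is now src.weight tiled 2x2 and divided by 2 (to account
        # for the duplicated input features), and dst.bias is src.bias repeated
        # twice, so a duplicated input [x, x] produces [src(x), src(x)] up to
        # numerical precision.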
186 | """ 187 | dst_layer.weight.data = clone_matrix( 188 | dst_layer.weight.shape, src_layer.weight.data, snr_db=snr_db 189 | ) 190 | if src_layer.bias is not None: 191 | assert ( 192 | dst_layer.bias is not None 193 | ), "source model has bias in its linear layers but destination model doesn't" 194 | dst_layer.bias.data = clone_vector(dst_layer.bias.shape, src_layer.bias.data) 195 | 196 | 197 | def clone_layer_norm(dst_layer, src_layer): 198 | """ 199 | Clones normalization layer parameters from 'src_layer' into 'dst_layer'. 200 | 201 | Arguments: 202 | dst_layer: 203 | Destination normalization layer. 204 | src_layer: 205 | Source pretrained normalization layer. 206 | 207 | Returns: 208 | None. 209 | """ 210 | if src_layer.weight is None and src_layer.bias is None: 211 | assert dst_layer.weight is None and dst_layer.bias is None 212 | return 213 | assert ( 214 | dst_layer.eps == src_layer.eps 215 | ), f"eps should be the same for source and destination layer-norms, \ 216 | got {src_layer.eps} and {dst_layer.eps}" 217 | assert ( 218 | dst_layer.elementwise_affine == src_layer.elementwise_affine 219 | ), f"elementwise_affine should be the same for source and destination \ 220 | layer-norms, got {src_layer.elementwise_affine} and {dst_layer.elementwise_affine}" 221 | dst_layer.weight.data = clone_vector(dst_layer.weight.shape, src_layer.weight) 222 | dst_layer.bias.data = clone_vector(dst_layer.bias.shape, src_layer.bias.data) 223 | 224 | 225 | def clone_rms_norm(dst_layer, src_layer): 226 | """ 227 | Clones rms-normalization layer parameters from 'src_layer' into 'dst_layer'. 228 | 229 | Arguments: 230 | dst_layer: 231 | Destination rms-normalization layer. 232 | src_layer: 233 | Source pretrained rms-normalization layer. 234 | 235 | Returns: 236 | None. 237 | """ 238 | dst_layer.weight.data = clone_vector(dst_layer.weight.shape, src_layer.weight) 239 | 240 | 241 | def rename_config( 242 | config, embedding_dim_multiplier: int = 1, up_project_multiplier: int = 1 243 | ): 244 | """ 245 | adjusts the model name according to 'embedding_dim_multiplier' and 'up_project_multiplier' 246 | Arguments: 247 | config: 248 | config to be modified. 249 | embedding_dim_multiplier: 250 | expansion ratio of embedding dimension. 251 | up_project_multiplier: 252 | expansion ratio of ffn layer. 253 | Returns: 254 | updated config. 255 | 256 | """ 257 | if embedding_dim_multiplier > 1: 258 | config._name_or_path += f"-{embedding_dim_multiplier}xembedding" 259 | if up_project_multiplier > 1: 260 | config._name_or_path += f"-{up_project_multiplier}xffn" 261 | return config 262 | 263 | 264 | class scaledLinear(torch.nn.Module): 265 | """ 266 | Wrapper layer that scales the weights of a linear layer before applying 267 | the linear transformation. This layer is useful in cases that embedding 268 | and unembedding layers are tied together, where the unembedding layer 269 | needs its weight to be scaled due to cloning but embedding layer should 270 | not scale the weights. 271 | Arguments: 272 | layer: 273 | original linear layer. 274 | scaler: 275 | scaler value. 
276 | """ 277 | 278 | def __init__(self, layer, scaler): 279 | super().__init__() 280 | self.layer = layer 281 | self.scaler = scaler 282 | self.weight = self.layer.weight 283 | self.bias = self.layer.bias 284 | 285 | def forward(self, x): 286 | weight = self.layer.weight * self.scaler 287 | if self.layer.bias is not None: 288 | bias = self.layer.bias * self.scaler 289 | else: 290 | bias = None 291 | return torch.nn.functional.linear(x, weight, bias) 292 | -------------------------------------------------------------------------------- /hypercloning/gemma_cloning.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | from transformers import Gemma2ForCausalLM, GemmaForCausalLM 11 | 12 | from hypercloning.common import (add_noise, clone_layer_norm, 13 | clone_linear_layer, clone_matrix, 14 | clone_rms_norm, rename_config, 15 | scale_linear_layer, scaledLinear) 16 | 17 | 18 | def clone_gemma_attention(dst_layer, src_layer, snr_db=None): 19 | """ 20 | Clones the attention layer from 'src_layer' into 'dst_layer'. 21 | 22 | Arguments: 23 | dst_layer: Destination attention layer. 24 | src_layer: Source (pretrained) attention layer. 25 | snr_db: signal to noise ratio. Defaults to None 26 | Returns: 27 | None. 28 | """ 29 | 30 | src_config = copy.deepcopy(src_layer.config) 31 | dst_config = copy.deepcopy(dst_layer.config) 32 | clone_gemma_qkv_layer( 33 | dst_layer.q_proj, 34 | src_layer.q_proj, 35 | dst_layer.num_heads, 36 | src_layer.num_heads, 37 | snr_db=snr_db, 38 | ) 39 | clone_gemma_qkv_layer( 40 | dst_layer.k_proj, 41 | src_layer.k_proj, 42 | dst_layer.num_key_value_heads, 43 | src_layer.num_key_value_heads, 44 | snr_db=snr_db, 45 | ) 46 | clone_gemma_qkv_layer( 47 | dst_layer.v_proj, 48 | src_layer.v_proj, 49 | dst_layer.num_key_value_heads, 50 | src_layer.num_key_value_heads, 51 | snr_db=snr_db, 52 | ) 53 | clone_linear_layer(dst_layer.o_proj, src_layer.o_proj, snr_db=snr_db) 54 | return dst_layer 55 | 56 | 57 | def clone_gemma_qkv_layer( 58 | dst_layer, src_layer, num_heads_dst, num_heads_src, snr_db=None 59 | ): 60 | """ 61 | Clones 'src_weight' into a weight tensor with 'dst_weight_shape' to be 62 | used in the attention layer. 63 | 64 | Arguments: 65 | dst_layer: 66 | Destination layer. 67 | src_layer: 68 | Source layer. 69 | num_heads_dst: 70 | Number of attention heads in the destination layer. 71 | num_heads_src: 72 | Number of attention heads in the source layer. 73 | snr_db: 74 | Signal-to-noise ratio. Defaults to None. 75 | 76 | Returns: 77 | None 78 | """ 79 | dst_layer.weight.data = clone_gemma_qkv_weight( 80 | dst_layer.weight.shape, 81 | src_layer.weight.data, 82 | num_heads_dst, 83 | num_heads_src, 84 | snr_db=snr_db, 85 | ) 86 | if src_layer.bias is not None: 87 | dst_layer.bias.data = clone_gemma_qkv_bias( 88 | dst_layer.bias.shape, 89 | src_layer.bias.data, 90 | num_heads_dst, 91 | num_heads_src, 92 | ) 93 | 94 | 95 | def clone_gemma_qkv_bias(dst_bias_shape, src_bias, num_heads_dst, num_heads_src): 96 | """ 97 | Clones 'src_bias' into a bias vector with 'dst_bias_shape' to be 98 | used in the attention layer. 99 | 100 | Arguments: 101 | dst_bias_shape: 102 | Shape of the bias tensor in the destination layer. 103 | src_bias: 104 | bias vector in the source layer. 105 | num_heads_dst: 106 | Number of attention heads in the destination layer. 
107 | num_heads_src: 108 | Number of attention heads in the source layer. 109 | 110 | Returns: 111 | Cloned QKV bias. 112 | """ 113 | source_qkv_dim = src_bias.shape[0] 114 | destination_qkv_dim = dst_bias_shape[0] 115 | n_repeat = destination_qkv_dim // source_qkv_dim 116 | dst_bias = src_bias.reshape(num_heads_src, source_qkv_dim // num_heads_src) 117 | n_repeat_heads = num_heads_dst // num_heads_src 118 | n_repeat_head_dim = n_repeat // n_repeat_heads 119 | dst_bias = dst_bias.repeat(n_repeat_heads, n_repeat_head_dim) 120 | dst_bias = dst_bias.reshape(destination_qkv_dim) 121 | return dst_bias 122 | 123 | 124 | def clone_gemma_qkv_weight( 125 | dst_weight_shape, src_weight, num_heads_dst, num_heads_src, snr_db=None 126 | ): 127 | """ 128 | Clones 'src_weight' into a weight tensor with 'dst_weight_shape' to be 129 | used in the attention layer. 130 | 131 | Arguments: 132 | dst_weight_shape: 133 | Shape of the weight tensor in the destination layer. 134 | src_weight: 135 | Weight tensor in the source layer. 136 | num_heads_dst: 137 | Number of attention heads in the destination layer. 138 | num_heads_src: 139 | Number of attention heads in the source layer. 140 | snr_db: 141 | Signal-to-noise ratio. Defaults to None. 142 | 143 | Returns: 144 | Cloned QKV weights. 145 | """ 146 | 147 | source_embedding_dimension = src_weight.shape[1] 148 | destination_embedding_dimension = dst_weight_shape[1] 149 | source_qkv_dim = src_weight.shape[0] 150 | destination_qkv_dim = dst_weight_shape[0] 151 | n_repeat_in = destination_embedding_dimension // source_embedding_dimension 152 | n_repeat = destination_qkv_dim // source_qkv_dim 153 | dst_weight = src_weight.reshape( 154 | num_heads_src, source_qkv_dim // num_heads_src, source_embedding_dimension 155 | ) 156 | block_shape = dst_weight.shape 157 | n_repeat_heads = num_heads_dst // num_heads_src 158 | n_repeat_head_dim = n_repeat // n_repeat_heads 159 | dst_weight = ( 160 | dst_weight.repeat(n_repeat_heads, n_repeat_head_dim, n_repeat_in) / n_repeat_in 161 | ) 162 | if snr_db is not None: 163 | dst_weight = add_noise(dst_weight, block_shape, snr_db) 164 | dst_weight = dst_weight.reshape( 165 | destination_qkv_dim, destination_embedding_dimension 166 | ) # (d_head, n_heads, e) --> #(d_head*n_heads, e) 167 | 168 | return dst_weight 169 | 170 | 171 | def clone_gemma( 172 | src_network, 173 | embedding_dim_multiplier: int = 1, 174 | up_project_multiplier: int = 1, 175 | **kwargs, 176 | ): 177 | """ 178 | Cloning function for the Gemma family. 179 | 180 | For arguments description, refer to hypercloning.cloneModel. 181 | 182 | Returns: 183 | Cloned Gemma model instance. 184 | """ 185 | snr_db = kwargs.get("snr_db", None) 186 | num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier) 187 | assert ( 188 | num_heads_multiplier == embedding_dim_multiplier 189 | ), "head_dim expansion is not supported for Gemma. The number of heads will \ 190 | be automatically computed based on embedding dimension expansion. 
Do not \ 191 | pass 'num_heads_multiplier' to 'clone_gemma'" 192 | 193 | # Set the destination network config according to user requested expansion factors: 194 | config = copy.deepcopy(src_network.config) 195 | config.hidden_size = embedding_dim_multiplier * config.hidden_size 196 | config.intermediate_size = up_project_multiplier * config.intermediate_size 197 | if config.num_key_value_heads != 1: 198 | config.num_key_value_heads = ( 199 | embedding_dim_multiplier * config.num_key_value_heads 200 | ) 201 | config.num_attention_heads = embedding_dim_multiplier * config.num_attention_heads 202 | 203 | # rename the config according to expansion factors 204 | config = rename_config(config, embedding_dim_multiplier, up_project_multiplier) 205 | 206 | # Make an instance of the destination network: 207 | dst_network = GemmaForCausalLM._from_config(config) 208 | 209 | # Note: Gemma multiplies the embedding tokens by sqrt(emb_dim). We should normalize by 210 | # 1/sqrt(embedding_dim_multiplier) to avoid a mismatch: 211 | 212 | dst_network.model.embed_tokens.weight.data = ( 213 | clone_matrix( 214 | dst_network.model.embed_tokens.weight.data.shape, 215 | src_network.model.embed_tokens.weight.data, 216 | normalize=False, 217 | ) 218 | * 1.0 219 | / np.sqrt(embedding_dim_multiplier) 220 | ) 221 | 222 | for dst_layer, src_layer in zip(dst_network.model.layers, src_network.model.layers): 223 | clone_rms_norm(dst_layer.input_layernorm, src_layer.input_layernorm) 224 | clone_rms_norm( 225 | dst_layer.post_attention_layernorm, src_layer.post_attention_layernorm 226 | ) 227 | dst_layer.self_attn = clone_gemma_attention( 228 | dst_layer.self_attn, src_layer.self_attn, snr_db=snr_db 229 | ) 230 | clone_linear_layer( 231 | dst_layer.mlp.gate_proj, src_layer.mlp.gate_proj, snr_db=snr_db 232 | ) 233 | clone_linear_layer(dst_layer.mlp.up_proj, src_layer.mlp.up_proj, snr_db=snr_db) 234 | clone_linear_layer( 235 | dst_layer.mlp.down_proj, src_layer.mlp.down_proj, snr_db=snr_db 236 | ) 237 | clone_rms_norm(dst_network.model.norm, src_network.model.norm) 238 | 239 | # Note: the unembedding layer is tied with the embedding layer. We need to divide the 240 | # weights of the unembedding layer by 'embedding_dim_multiplier' but the weights in the 241 | # embedding layer should not be divided. but note that the embedding weights were already 242 | # divided by 'sqrt(embedding_dim_multiplier)' for Embedding initialization at the begining 243 | # of this function. So we use a wrapper class around lm_head that divides the weights in 244 | # the unembedding forward function by 'sqrt(embedding_dim_multiplier)' one more time: 245 | 246 | if embedding_dim_multiplier > 1: 247 | dst_network.lm_head = scaledLinear( 248 | dst_network.lm_head, 1.0 / np.sqrt(embedding_dim_multiplier) 249 | ) 250 | return dst_network 251 | 252 | 253 | def clone_gemma2( 254 | src_network, 255 | embedding_dim_multiplier: int = 1, 256 | up_project_multiplier: int = 1, 257 | **kwargs, 258 | ): 259 | """ 260 | Cloning function for the Gemma2 family. 261 | 262 | For arguments description, refer to hypercloning.cloneModel. 263 | 264 | Returns: 265 | Cloned Gemma model instance. 266 | """ 267 | snr_db = kwargs.get("snr_db", None) 268 | num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier) 269 | assert ( 270 | num_heads_multiplier == embedding_dim_multiplier 271 | ), "head_dim expansion is not supported for Gemma. The number of heads will \ 272 | be automatically computed based on embedding dimension expansion. 
Do not \ 273 | pass 'num_heads_multiplier' to 'clone_gemma2'" 274 | 275 | # Set the destination network config according to user requested expansion factors: 276 | config = copy.deepcopy(src_network.config) 277 | config.hidden_size = embedding_dim_multiplier * config.hidden_size 278 | config.intermediate_size = up_project_multiplier * config.intermediate_size 279 | if config.num_key_value_heads != 1: 280 | config.num_key_value_heads = ( 281 | embedding_dim_multiplier * config.num_key_value_heads 282 | ) 283 | config.num_attention_heads = embedding_dim_multiplier * config.num_attention_heads 284 | 285 | # rename the config according to expansion factors 286 | config = rename_config(config, embedding_dim_multiplier, up_project_multiplier) 287 | 288 | # Make an instance of the destination network: 289 | dst_network = Gemma2ForCausalLM._from_config(config) 290 | 291 | # Note: Gemma multiplies the embedding tokens by sqrt(emb_dim). We should normalize by 292 | # 1/sqrt(embedding_dim_multiplier) to avoid a mismatch: 293 | 294 | dst_network.model.embed_tokens.weight.data = ( 295 | clone_matrix( 296 | dst_network.model.embed_tokens.weight.data.shape, 297 | src_network.model.embed_tokens.weight.data, 298 | normalize=False, 299 | ) 300 | * 1.0 301 | / np.sqrt(embedding_dim_multiplier) 302 | ) 303 | 304 | for dst_layer, src_layer in zip(dst_network.model.layers, src_network.model.layers): 305 | clone_rms_norm(dst_layer.input_layernorm, src_layer.input_layernorm) 306 | clone_rms_norm( 307 | dst_layer.post_attention_layernorm, src_layer.post_attention_layernorm 308 | ) 309 | clone_rms_norm( 310 | dst_layer.pre_feedforward_layernorm, src_layer.pre_feedforward_layernorm 311 | ) 312 | clone_rms_norm( 313 | dst_layer.post_feedforward_layernorm, src_layer.post_feedforward_layernorm 314 | ) 315 | dst_layer.self_attn = clone_gemma_attention( 316 | dst_layer.self_attn, src_layer.self_attn, snr_db=snr_db 317 | ) 318 | clone_linear_layer( 319 | dst_layer.mlp.gate_proj, src_layer.mlp.gate_proj, snr_db=snr_db 320 | ) 321 | clone_linear_layer(dst_layer.mlp.up_proj, src_layer.mlp.up_proj, snr_db=snr_db) 322 | clone_linear_layer( 323 | dst_layer.mlp.down_proj, src_layer.mlp.down_proj, snr_db=snr_db 324 | ) 325 | clone_rms_norm(dst_network.model.norm, src_network.model.norm) 326 | 327 | # Note: the unembedding layer is tied with the embedding layer. We need to divide the 328 | # weights of the unembedding layer by 'embedding_dim_multiplier' but the weights in the 329 | # embedding layer should not be divided. but note that the embedding weights were already 330 | # divided by 'sqrt(embedding_dim_multiplier)' for Embedding initialization at the begining 331 | # of this function. So we use a wrapper class around lm_head that divides the weights in 332 | # the unembedding forward function by 'sqrt(embedding_dim_multiplier)' one more time: 333 | 334 | if embedding_dim_multiplier > 1: 335 | dst_network.lm_head = scaledLinear( 336 | dst_network.lm_head, 1.0 / np.sqrt(embedding_dim_multiplier) 337 | ) 338 | return dst_network 339 | -------------------------------------------------------------------------------- /hypercloning/llama_cloning.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | from transformers import LlamaForCausalLM 11 | 12 | from hypercloning.common import (clone_layer_norm, clone_linear_layer, 13 | clone_matrix, clone_rms_norm, rename_config, 14 | scale_linear_layer, scaledLinear) 15 | from hypercloning.gemma_cloning import clone_gemma_attention 16 | 17 | 18 | def clone_llama( 19 | src_network, 20 | embedding_dim_multiplier: int = 1, 21 | up_project_multiplier: int = 1, 22 | **kwargs, 23 | ): 24 | """ 25 | Cloning function for the Llama family. 26 | 27 | For arguments description, refer to hypercloning.cloneModel. 28 | 29 | Returns: 30 | Cloned Llama model instance. 31 | """ 32 | snr_db = kwargs.get("snr_db", None) 33 | num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier) 34 | assert ( 35 | num_heads_multiplier == embedding_dim_multiplier 36 | ), "head_dim expansion is not supported for Llama. The number of heads will \ 37 | be automatically computed based on embedding dimension expansion. Do not \ 38 | pass 'num_heads_multiplier' to 'clone_llama'" 39 | 40 | # Set the destination network config according to user requested expansion factors: 41 | config = copy.deepcopy(src_network.config) 42 | config.hidden_size = embedding_dim_multiplier * config.hidden_size 43 | config.intermediate_size = up_project_multiplier * config.intermediate_size 44 | if config.num_key_value_heads != 1: 45 | config.num_key_value_heads = ( 46 | embedding_dim_multiplier * config.num_key_value_heads 47 | ) 48 | config.num_attention_heads = embedding_dim_multiplier * config.num_attention_heads 49 | config.tie_word_embeddings = False 50 | # rename the config according to expansion factors 51 | config = rename_config(config, embedding_dim_multiplier, up_project_multiplier) 52 | 53 | # Make an instance of the destination network: 54 | dst_network = LlamaForCausalLM._from_config(config) 55 | 56 | dst_network.model.embed_tokens.weight.data = clone_matrix( 57 | dst_network.model.embed_tokens.weight.data.shape, 58 | src_network.model.embed_tokens.weight.data, 59 | normalize=False, 60 | ) 61 | 62 | for dst_layer, src_layer in zip(dst_network.model.layers, src_network.model.layers): 63 | clone_rms_norm(dst_layer.input_layernorm, src_layer.input_layernorm) 64 | clone_rms_norm( 65 | dst_layer.post_attention_layernorm, src_layer.post_attention_layernorm 66 | ) 67 | dst_layer.self_attn = clone_gemma_attention( 68 | dst_layer.self_attn, src_layer.self_attn, snr_db=snr_db 69 | ) 70 | clone_linear_layer( 71 | dst_layer.mlp.gate_proj, src_layer.mlp.gate_proj, snr_db=snr_db 72 | ) 73 | clone_linear_layer(dst_layer.mlp.up_proj, src_layer.mlp.up_proj, snr_db=snr_db) 74 | clone_linear_layer( 75 | dst_layer.mlp.down_proj, src_layer.mlp.down_proj, snr_db=snr_db 76 | ) 77 | clone_rms_norm(dst_network.model.norm, src_network.model.norm) 78 | clone_linear_layer(dst_network.lm_head, src_network.lm_head) 79 | return dst_network 80 | -------------------------------------------------------------------------------- /hypercloning/olmo_cloning.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | from transformers import AutoModelForCausalLM 11 | 12 | from hypercloning.common import (add_noise, clone_layer_norm, 13 | clone_linear_layer, clone_matrix, 14 | rename_config, scale_linear_layer, 15 | scaledLinear) 16 | 17 | 18 | class clonedPositionalEmbedding(torch.nn.Module): 19 | """ 20 | Clones a source positional embedding layer called 'emb_module'. 21 | 22 | This function assumes 'emb_module' receives an input 'x' and computes P(x). 23 | The cloned module takes repeated inputs [x, ..., x]^T and computes the 24 | output as [P(x), ..., P(x)]^T. 25 | 26 | Arguments: 27 | emb_module: 28 | Original positional embedding module from the source network. 29 | repeat: 30 | Expansion factor indicating how many times the module should 31 | be called repeatedly. 32 | 33 | """ 34 | 35 | def __init__(self, emb_module, repeat): 36 | super().__init__() 37 | self.repeat = repeat 38 | self.module = emb_module 39 | 40 | def forward(self, q, k): 41 | """ 42 | shape of q: (...D) where D should be split. 43 | shape of k: (...D) where D should be split. 44 | """ 45 | assert ( 46 | q.shape[-1] % self.repeat 47 | ) == 0, ( 48 | f"q tensor dimension ({q.shape}) does not divide repeats ({self.repeat})" 49 | ) 50 | assert ( 51 | k.shape[-1] % self.repeat 52 | ) == 0, ( 53 | f"k tensor dimension ({k.shape}) does not divide repeats ({self.repeat})" 54 | ) 55 | q_split = torch.split(q, q.shape[-1] // self.repeat, dim=-1) 56 | k_split = torch.split(k, k.shape[-1] // self.repeat, dim=-1) 57 | outputs = [self.module(qq, kk) for qq, kk in zip(q_split, k_split)] 58 | qs = [o[0] for o in outputs] 59 | ks = [o[1] for o in outputs] 60 | return torch.cat(qs, dim=-1), torch.cat(ks, dim=-1) 61 | 62 | 63 | def reorder_swiglu(tensor, n_repeat): 64 | """ 65 | Reorders the rows in the first linear layer of the FFN to correct 66 | the gating that occurs in the subsequent SWIGLU activation function. 67 | 68 | The original linear layer produces the output [x_top, x_bottom]^T, 69 | and the following SWIGLU computes the activation as 'silu(x_top) * x_bottom'. 70 | 71 | In contrast, the cloned layer produces the output [x_up, x_down, x_up, x_down]^T. 72 | The SWIGLU incorrectly computes 'silu([x_top, x_bottom]^T) * [x_top, x_bottom]^T', 73 | which is not the desired behavior. 74 | 75 | To fix this, we need to reorder the weights of the cloned linear layer 76 | so that it produces [x_up, x_up, x_down, x_down]^T. This allows the SWIGLU 77 | to correctly compute 'silu([x_top, x_top]^T) * [x_bottom, x_bottom]^T'. 78 | 79 | Arguments: 80 | tensor: 81 | The weight or bias tensor in the destination FFN block. 82 | n_repeat: 83 | The expansion factor from the source to the destination FFN block. 84 | 85 | Returns: 86 | Reordered tensor. 87 | """ 88 | n = int(tensor.shape[0] // (n_repeat * 2)) 89 | tensors_up = [tensor[2 * i * n : (2 * i + 1) * n] for i in range(0, n_repeat)] 90 | tensors_down = [ 91 | tensor[(2 * i + 1) * n : (2 * i + 2) * n] for i in range(0, n_repeat) 92 | ] 93 | all_tensors = tensors_up + tensors_down 94 | return torch.cat(all_tensors, dim=0) 95 | 96 | 97 | def reorder_weights(w, n_heads_old, head_dim_old, n_repeat_dim, n_repeat_heads): 98 | """ 99 | Reorders the columns of the out_project linear layer at the end of the 100 | attention layer. 101 | 102 | This function is meant to preserve the functionality of the attention 103 | block when the head_dim is changed. 
104 | 105 | Arguments: 106 | w: 107 | The weight tensor to be reordered (from the o_proj linear layer). 108 | n_heads_old: 109 | Number of heads in the source attention layer. 110 | head_dim_old: 111 | Dimension of each head in the source attention layer. 112 | n_repeat_dim: 113 | Number of times the head-dim of the source attention layer is 114 | repeated in the destination attention layer. 115 | n_repeat_heads: 116 | Number of times the heads of the source attention layer are 117 | repeated in the destination attention layer. 118 | 119 | Returns: 120 | Reordered weights. 121 | """ 122 | w = w.reshape(w.shape[0], n_repeat_heads, n_repeat_dim, n_heads_old, head_dim_old) 123 | sh_old = copy.copy(w.shape) 124 | w = w.permute(0, 1, 3, 2, 4) 125 | sh_new = copy.copy(w.shape) 126 | w = w.reshape(w.shape[0], -1) 127 | return w 128 | 129 | 130 | def clone_olmo_qkv_layer( 131 | dst_layer, src_layer, num_heads_dst, num_heads_src, snr_db=None 132 | ): 133 | """ 134 | Clones a source QKV linear layer into a destination QKV linear layer. 135 | 136 | Arguments: 137 | dst_layer: Destination layer. 138 | src_layer: Source layer. 139 | num_heads_dst: Number of attention heads in the destination layer. 140 | num_heads_src: Number of attention heads in the source layer. 141 | snr_db: Signal-to-noise ratio. Defaults to None. 142 | 143 | Returns: 144 | None. 145 | """ 146 | 147 | dst_layer.weight.data = clone_olmo_qkv_weight( 148 | dst_layer.weight.shape, 149 | src_layer.weight.data, 150 | num_heads_dst, 151 | num_heads_src, 152 | snr_db=snr_db, 153 | ) 154 | if src_layer.bias is not None: 155 | assert ( 156 | dst_layer.bias is not None 157 | ), "source model has bias in it's linear layers but destination model doesn't" 158 | dst_layer.bias.data = clone_olmo_qkv_bias( 159 | dst_layer.bias.shape, src_layer.bias.data, num_heads_dst, num_heads_src 160 | ) 161 | 162 | 163 | def clone_olmo_qkv_weight( 164 | dst_weight_shape, src_weight, num_heads_dst, num_heads_src, snr_db=None 165 | ): 166 | """ 167 | Clones 'src_weight' into a weight tensor with 'dst_weight_shape' to be 168 | used in the attention layer. 169 | 170 | Arguments: 171 | dst_weight_shape: 172 | Shape of the weight tensor in the destination layer. 173 | src_weight: 174 | Weight tensor in the source layer. 175 | num_heads_dst: 176 | Number of attention heads in the destination layer. 177 | num_heads_src: 178 | Number of attention heads in the source layer. 179 | snr_db: 180 | Signal-to-noise ratio. Defaults to None. 181 | 182 | Returns: 183 | Cloned QKV weights. 
184 | """ 185 | assert src_weight.shape[0] == (3 * src_weight.shape[1]) 186 | assert dst_weight_shape[0] == (3 * dst_weight_shape[1]) 187 | source_embedding_dim = src_weight.shape[1] 188 | destination_embedding_dim = dst_weight_shape[1] 189 | n_repeat = destination_embedding_dim // source_embedding_dim 190 | dst_weight = src_weight.reshape( 191 | 3, num_heads_src, source_embedding_dim // num_heads_src, source_embedding_dim 192 | ) # (3, H, E/H, E) 193 | block_shape = dst_weight.shape 194 | head_repeat = num_heads_dst // num_heads_src 195 | dim_repeat = n_repeat // head_repeat 196 | dst_weight = ( 197 | dst_weight.repeat(1, head_repeat, dim_repeat, n_repeat) / n_repeat 198 | ) # (3, nH, E/H, nE) 199 | dst_weight[:2] = dst_weight[:2] / np.sqrt( 200 | np.sqrt(dim_repeat) 201 | ) ##divide query and key weights to compensate for normalization 202 | if snr_db is not None: 203 | dst_weight = add_noise(dst_weight, block_shape, snr_db) 204 | dst_weight = dst_weight.reshape( 205 | 3 * destination_embedding_dim, destination_embedding_dim 206 | ) # (3, n_heads, d_head, e) --> #(3*n_heads*d_head, e) 207 | 208 | return dst_weight 209 | 210 | 211 | def clone_olmo_qkv_bias(dst_bias_shape, src_bias, num_heads_dst, num_heads_src): 212 | assert False, "not implemented" 213 | 214 | 215 | def clone_olmo( 216 | src_network, 217 | embedding_dim_multiplier: int = 1, 218 | up_project_multiplier: int = 1, 219 | **kwargs, 220 | ): 221 | """ 222 | Clones the OLMo network. See hypercloning.cloneModel for argument descriptions. 223 | 224 | Returns: 225 | Cloned OLMo network instance. 226 | """ 227 | 228 | # Check if user has specified num_heads_multiplier manually. If not, set it 229 | # to 'embedding_dim_multiplier'. 230 | num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier) 231 | snr_db = kwargs.get("snr_db", None) 232 | # Set the destination network config according to user requested expansion factors: 233 | config = copy.deepcopy(src_network.config) 234 | setattr(config, "d_model", embedding_dim_multiplier * config.d_model) 235 | if getattr(config, "mlp_hidden_size", None) is None: 236 | setattr( 237 | config, 238 | "mlp_hidden_size", 239 | src_network.config.d_model * src_network.config.mlp_ratio, 240 | ) 241 | setattr(config, "mlp_hidden_size", up_project_multiplier * config.mlp_hidden_size) 242 | setattr(config, "n_heads", num_heads_multiplier * config.n_heads) 243 | 244 | # rename the config according to expansion factors 245 | config = rename_config(config, embedding_dim_multiplier, up_project_multiplier) 246 | 247 | # lazy approach: disable weight tying to avoid mismatch 248 | config.weight_tying = False 249 | 250 | old_head_dim = src_network.config.d_model // src_network.config.n_heads 251 | new_head_dim = config.d_model // config.n_heads 252 | 253 | assert ( 254 | new_head_dim % old_head_dim == 0 255 | ), f"new head dimension ({new_head_dim}) should divide original head \ 256 | dimension ({old_head_dim}). 
Consider changing num_heads_multiplier \ 257 | ({num_heads_multiplier})" 258 | dst_network = AutoModelForCausalLM.from_config( 259 | config, trust_remote_code=True, torch_dtype=torch.bfloat16 260 | ) 261 | 262 | # Clone the embedding layer: 263 | dst_network.model.transformer.wte.weight.data = clone_matrix( 264 | dst_network.model.transformer.wte.weight.data.shape, 265 | src_network.model.transformer.wte.weight.data, 266 | normalize=False, 267 | ) 268 | 269 | # Clone layernorm: 270 | clone_layer_norm( 271 | dst_network.model.transformer.ln_f, src_network.model.transformer.ln_f 272 | ) 273 | 274 | # Iterate through decoder layers and clone their components: 275 | for dst_layer, src_layer in zip( 276 | dst_network.model.transformer.blocks, src_network.model.transformer.blocks 277 | ): 278 | clone_linear_layer(dst_layer.attn_out, src_layer.attn_out, snr_db=snr_db) 279 | 280 | # If head_dim has changed, we need to re-order the weights of dst_layer.attn_out 281 | # to match the expansion in attention heads: 282 | if new_head_dim != old_head_dim: 283 | dst_layer.attn_out.weight.data = reorder_weights( 284 | w=dst_layer.attn_out.weight.data, 285 | n_heads_old=src_layer.config.n_heads, 286 | head_dim_old=old_head_dim, 287 | n_repeat_dim=new_head_dim // old_head_dim, 288 | n_repeat_heads=num_heads_multiplier, 289 | ) 290 | 291 | clone_linear_layer(dst_layer.ff_out, src_layer.ff_out, snr_db=snr_db) 292 | clone_layer_norm(dst_layer.attn_norm, src_layer.attn_norm) 293 | clone_layer_norm(dst_layer.ff_norm, src_layer.ff_norm) 294 | clone_olmo_qkv_layer( 295 | dst_layer.att_proj, 296 | src_layer.att_proj, 297 | dst_layer.config.n_heads, 298 | src_layer.config.n_heads, 299 | snr_db=snr_db, 300 | ) 301 | clone_linear_layer(dst_layer.ff_proj, src_layer.ff_proj, snr_db=snr_db) 302 | 303 | # If the input to SWIGLU activaion has changed dimension, we should reorder weights: 304 | swiglu_repeat = dst_layer.ff_proj.out_features // ( 305 | src_layer.ff_proj.out_features 306 | ) 307 | if swiglu_repeat > 1: 308 | dst_layer.ff_proj.weight.data = reorder_swiglu( 309 | dst_layer.ff_proj.weight.data, swiglu_repeat 310 | ) 311 | if dst_layer.ff_proj.bias is not None: 312 | dst_layer.ff_proj.bias.data = reorder_swiglu( 313 | dst_layer.ff_proj.bias.data, swiglu_repeat 314 | ) 315 | 316 | # If the head dimension has changed, we should fix positional embedding: 317 | if new_head_dim > old_head_dim: 318 | dst_layer.rotary_emb = clonedPositionalEmbedding( 319 | src_layer.rotary_emb, new_head_dim // old_head_dim 320 | ) 321 | 322 | if not src_network.config.weight_tying: 323 | # Clone the unembedding layer from the source network unembedding: 324 | dst_network.model.transformer.ff_out.weight.data = clone_matrix( 325 | dst_network.model.transformer.ff_out.weight.data.shape, 326 | src_network.model.transformer.ff_out.weight.data, 327 | normalize=True, 328 | ) 329 | else: 330 | # Clone the unembedding layer from the source network embedding layer: 331 | dst_network.model.transformer.ff_out.weight.data = clone_matrix( 332 | dst_network.model.transformer.ff_out.weight.data.shape, 333 | src_network.model.transformer.wte.weight.data, 334 | normalize=True, 335 | ) 336 | return dst_network 337 | -------------------------------------------------------------------------------- /hypercloning/opt_cloning.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
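# ---------------------------------------------------------------------------
# Sketch (illustrative only, not part of the original repository): a tiny
# numeric check of reorder_swiglu from olmo_cloning.py above. The toy tensor
# and n_repeat value are arbitrary assumptions chosen to make the reordering
# visible.
def _example_reorder_swiglu():
    import torch

    from hypercloning.olmo_cloning import reorder_swiglu

    # Toy "cloned" ff_proj output rows: the source layer produced
    # [x_top (rows 0-1), x_bottom (rows 2-3)] and cloning repeated that
    # pattern once more (rows 4-7).
    w = torch.arange(8, dtype=torch.float32).unsqueeze(1)  # shape (8, 1)
    reordered = reorder_swiglu(w, n_repeat=2)
    # All x_top rows now come first, then all x_bottom rows, so SwiGLU's
    # split-in-half gating multiplies matching copies:
    assert reordered.squeeze(1).tolist() == [0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0]
    return reordered
# ---------------------------------------------------------------------------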
4 | #
5 | 
6 | import copy
7 | 
8 | import numpy as np
9 | import torch
10 | from transformers import OPTForCausalLM
11 | 
12 | from hypercloning.common import (clone_layer_norm, clone_linear_layer,
13 |                                  clone_matrix, rename_config,
14 |                                  scale_linear_layer, scaledLinear)
15 | 
16 | 
17 | def clone_positional_embedding_layer(dst_layer, src_layer):
18 |     """
19 |     Clones the positional embedding parameters from 'src_layer' into 'dst_layer'.
20 | 
21 |     Arguments:
22 |         dst_layer: Destination layer.
23 |         src_layer: Source (pretrained) layer.
24 | 
25 |     Returns:
26 |         None.
27 |     """
28 | 
29 |     src_weight = src_layer.weight.data
30 |     dst_weight_shape = dst_layer.weight.shape
31 |     assert src_weight.shape[1] <= dst_weight_shape[1]
32 |     assert src_weight.shape[0] == dst_weight_shape[0]
33 |     assert dst_weight_shape[1] % src_weight.shape[1] == 0
34 |     n_repeat = dst_weight_shape[1] // src_weight.shape[1]
35 |     dst_layer.weight.data = src_weight.repeat(1, n_repeat)
36 | 
37 | 
38 | def clone_opt(
39 |     src_network,
40 |     embedding_dim_multiplier: int = 1,
41 |     up_project_multiplier: int = 1,
42 |     **kwargs,
43 | ):
44 |     """
45 |     Cloning function for the OPT family.
46 | 
47 |     For argument descriptions, refer to hypercloning.cloneModel.
48 | 
49 |     Returns:
50 |         Cloned OPT model instance.
51 |     """
52 | 
53 |     # Check if user has specified num_heads_multiplier manually. If not, set it
54 |     # to 'embedding_dim_multiplier'.
55 |     num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier)
56 |     snr_db = kwargs.get("snr_db", None)
57 |     # Set the destination network config according to the user-requested expansion factors:
58 |     config = copy.deepcopy(src_network.config)
59 |     setattr(config, "hidden_size", embedding_dim_multiplier * config.hidden_size)
60 |     setattr(config, "ffn_dim", up_project_multiplier * config.ffn_dim)
61 |     setattr(
62 |         config, "num_attention_heads", num_heads_multiplier * config.num_attention_heads
63 |     )
64 | 
65 |     old_head_dim = (
66 |         src_network.config.hidden_size // src_network.config.num_attention_heads
67 |     )
68 |     new_head_dim = config.hidden_size // config.num_attention_heads
69 | 
70 |     # rename the config according to expansion factors
71 |     config = rename_config(config, embedding_dim_multiplier, up_project_multiplier)
72 | 
73 |     # If head_dim changes, the attention query and key layers must be scaled:
74 |     attention_scaler = np.sqrt(np.sqrt(old_head_dim * 1.0 / new_head_dim))
75 | 
76 |     assert (
77 |         new_head_dim % old_head_dim == 0
78 |     ), f"new head dimension ({new_head_dim}) must be an integer multiple of the original head dimension \
79 |     ({old_head_dim}).
Consider changing num_heads_multiplier ({num_heads_multiplier})" 80 | 81 | # Make an instance of the destination network: 82 | dst_network = OPTForCausalLM._from_config(config) 83 | 84 | # Set the embedding layer parameters: 85 | dst_network.model.decoder.embed_tokens.weight.data = clone_matrix( 86 | dst_network.model.decoder.embed_tokens.weight.data.shape, 87 | src_network.model.decoder.embed_tokens.weight.data, 88 | normalize=False, 89 | ) 90 | 91 | # If the network uses 'project_in' and 'project_out', clone these parameters: 92 | if dst_network.model.decoder.project_in is not None: 93 | clone_linear_layer( 94 | dst_network.model.decoder.project_in, 95 | src_network.model.decoder.project_in, 96 | snr_db=snr_db, 97 | ) 98 | clone_linear_layer( 99 | dst_network.model.decoder.project_out, 100 | src_network.model.decoder.project_out, 101 | snr_db=snr_db, 102 | ) 103 | 104 | # Clone the positional embedding layer parameters: 105 | clone_positional_embedding_layer( 106 | dst_network.model.decoder.embed_positions, 107 | src_network.model.decoder.embed_positions, 108 | ) 109 | 110 | # Clone the final layer norm if required: 111 | if src_network.model.decoder.final_layer_norm is not None: 112 | clone_layer_norm( 113 | dst_network.model.decoder.final_layer_norm, 114 | src_network.model.decoder.final_layer_norm, 115 | ) 116 | 117 | # Iterate through the decoder layers and clone the components: 118 | for dst_layer, src_layer in zip( 119 | dst_network.model.decoder.layers, src_network.model.decoder.layers 120 | ): 121 | clone_linear_layer( 122 | dst_layer.self_attn.k_proj, src_layer.self_attn.k_proj, snr_db=snr_db 123 | ) 124 | scale_linear_layer(dst_layer.self_attn.k_proj, attention_scaler) 125 | clone_linear_layer( 126 | dst_layer.self_attn.q_proj, src_layer.self_attn.q_proj, snr_db=snr_db 127 | ) 128 | scale_linear_layer(dst_layer.self_attn.q_proj, attention_scaler) 129 | clone_linear_layer( 130 | dst_layer.self_attn.v_proj, src_layer.self_attn.v_proj, snr_db=snr_db 131 | ) 132 | clone_linear_layer( 133 | dst_layer.self_attn.out_proj, src_layer.self_attn.out_proj, snr_db=snr_db 134 | ) 135 | clone_layer_norm(dst_layer.self_attn_layer_norm, src_layer.self_attn_layer_norm) 136 | clone_linear_layer(dst_layer.fc1, src_layer.fc1, snr_db=snr_db) 137 | clone_linear_layer(dst_layer.fc2, src_layer.fc2, snr_db=snr_db) 138 | clone_layer_norm(dst_layer.final_layer_norm, src_layer.final_layer_norm) 139 | 140 | # Note: the unembedding layer is tied with the embedding layer. We need to divide the 141 | # weights of the unembedding layer by 'embedding_dim_multiplier' but the weights in the 142 | # embedding layer should not be divided. Therefore, we use a wrapper class around lm_head: 143 | if embedding_dim_multiplier > 1: 144 | dst_network.lm_head = scaledLinear( 145 | dst_network.lm_head, 1.0 / embedding_dim_multiplier 146 | ) 147 | 148 | return dst_network 149 | -------------------------------------------------------------------------------- /hypercloning/pythia_cloning.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All Rights Reserved. 
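# ---------------------------------------------------------------------------
# Sketch (illustrative only, not part of the original repository): a toy check
# of clone_positional_embedding_layer from opt_cloning.py above, using plain
# nn.Embedding modules as stand-ins for OPT's learned positional embeddings.
# The sizes (4 positions, 2 -> 6 features) are arbitrary assumptions.
def _example_clone_positional_embedding():
    import torch

    from hypercloning.opt_cloning import clone_positional_embedding_layer

    src = torch.nn.Embedding(4, 2)  # 4 positions, source embedding dim 2
    dst = torch.nn.Embedding(4, 6)  # same positions, destination dim 3x wider
    clone_positional_embedding_layer(dst, src)
    # Each destination row is the corresponding source row tiled 3 times along
    # the feature axis, so every repeated copy of a hidden state sees the same
    # positional offsets.
    assert torch.equal(dst.weight.data, src.weight.data.repeat(1, 3))
    return dst
# ---------------------------------------------------------------------------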
4 | # 5 | 6 | import copy 7 | 8 | import numpy as np 9 | import torch 10 | from transformers import (AutoModelForCausalLM, AutoTokenizer, GPTNeoXConfig, 11 | GPTNeoXForCausalLM) 12 | 13 | from hypercloning.common import (add_noise, clone_layer_norm, 14 | clone_linear_layer, clone_matrix, 15 | rename_config) 16 | 17 | 18 | def clone_pythia_qkv_layer( 19 | dst_layer, src_layer, num_heads_dst, num_heads_src, snr_db=None 20 | ): 21 | """ 22 | Clones a source QKV linear layer into a destination QKV linear layer. 23 | 24 | Arguments: 25 | dst_layer: Destination layer. 26 | src_layer: Source layer. 27 | num_heads_dst: Number of attention heads in the destination layer. 28 | num_heads_src: Number of attention heads in the source layer. 29 | snr_db: Signal-to-noise ratio. Defaults to None. 30 | 31 | Returns: 32 | None. 33 | """ 34 | 35 | dst_layer.weight.data = clone_pythia_qkv_weight( 36 | dst_layer.weight.shape, 37 | src_layer.weight.data, 38 | num_heads_dst, 39 | num_heads_src, 40 | snr_db=snr_db, 41 | ) 42 | if src_layer.bias is not None: 43 | assert ( 44 | dst_layer.bias is not None 45 | ), "source model has bias in it's linear layers but destination model doesn't" 46 | dst_layer.bias.data = clone_pythia_qkv_bias( 47 | dst_layer.bias.shape, src_layer.bias.data, num_heads_dst, num_heads_src 48 | ) 49 | 50 | 51 | def clone_pythia_qkv_weight( 52 | dst_weight_shape, src_weight, num_heads_dst, num_heads_src, snr_db=None 53 | ): 54 | """ 55 | Clones 'src_weight' into a weight tensor with 'dst_weight_shape' to be 56 | used in the attention layer. 57 | 58 | Arguments: 59 | dst_weight_shape: 60 | Shape of the weight tensor in the destination layer. 61 | src_weight: 62 | Weight tensor in the source layer. 63 | num_heads_dst: 64 | Number of attention heads in the destination layer. 65 | num_heads_src: 66 | Number of attention heads in the source layer. 67 | snr_db: 68 | Signal-to-noise ratio. Defaults to None. 69 | 70 | Returns: 71 | Cloned QKV weights. 72 | """ 73 | 74 | assert src_weight.shape[0] == (3 * src_weight.shape[1]) 75 | assert dst_weight_shape[0] == (3 * dst_weight_shape[1]) 76 | source_embedding_dim = src_weight.shape[1] 77 | destination_embedding_dim = dst_weight_shape[1] 78 | n_repeat = destination_embedding_dim // source_embedding_dim 79 | if num_heads_dst // num_heads_src == n_repeat: 80 | assert ( 81 | source_embedding_dim // num_heads_src 82 | == destination_embedding_dim // num_heads_dst 83 | ), "either number of heads or head-dims should be the same in source and \ 84 | cloned network. Cannot change both!" 85 | dst_weight = src_weight.reshape( 86 | num_heads_src, 87 | 3, 88 | source_embedding_dim // num_heads_src, 89 | source_embedding_dim, 90 | ) 91 | block_shape = dst_weight.shape 92 | dst_weight = dst_weight.repeat(n_repeat, 1, 1, n_repeat) / n_repeat 93 | else: 94 | assert ( 95 | num_heads_src == num_heads_dst 96 | ), "either number of heads or head-dims should be the same in source and \ 97 | cloned network. Cannot change both!" 
98 | dst_weight = src_weight.reshape( 99 | 3, 100 | num_heads_src, 101 | source_embedding_dim // num_heads_src, 102 | source_embedding_dim, 103 | ) 104 | dst_weight = dst_weight.repeat(1, 1, n_repeat, n_repeat) / n_repeat 105 | block_shape = dst_weight.shape 106 | if snr_db is not None: 107 | dst_weight = add_noise(dst_weight, block_shape, snr_db) 108 | dst_weight = dst_weight.reshape( 109 | 3 * destination_embedding_dim, destination_embedding_dim 110 | ) # (3, n_heads, d_head, e) --> #(3*n_heads*d_head, e) 111 | 112 | return dst_weight 113 | 114 | 115 | def clone_pythia_qkv_bias(dst_bias_shape, src_bias, num_heads_dst, num_heads_src): 116 | """ 117 | Clones 'src_bias' into a bias tensor with 'dst_bias_shape' to be used 118 | in the attention layer. 119 | 120 | Arguments: 121 | dst_bias_shape: 122 | Shape of the bias tensor in the destination layer. 123 | src_bias: 124 | Bias tensor in the source layer. 125 | num_heads_dst: 126 | Number of attention heads in the destination layer. 127 | num_heads_src: 128 | Number of attention heads in the source layer. 129 | 130 | Returns: 131 | Cloned QKV bias. 132 | """ 133 | 134 | source_embedding_dim = src_bias.shape[0] // 3 135 | destination_embedding_dim = dst_bias_shape[0] // 3 136 | n_repeat = destination_embedding_dim // source_embedding_dim 137 | if num_heads_dst // num_heads_src == n_repeat: 138 | dst_bias = src_bias.reshape( 139 | num_heads_src, 3, source_embedding_dim // num_heads_src 140 | ) 141 | dst_bias = dst_bias.repeat(n_repeat, 1, 1) 142 | else: 143 | dst_bias = src_bias.reshape( 144 | 3, num_heads_src, source_embedding_dim // num_heads_src 145 | ) 146 | dst_bias = dst_bias.repeat(1, 1, n_repeat) 147 | return dst_bias.reshape(-1) 148 | 149 | 150 | def clone_GPTNeoXAttention(dst_layer, src_layer, snr_db=None): 151 | """ 152 | Clones the attention layer from 'src_layer' into 'dst_layer'. 153 | 154 | Arguments: 155 | dst_layer: Destination attention layer. 156 | src_layer: Source (pretrained) attention layer. 157 | 158 | Returns: 159 | None. 160 | """ 161 | 162 | clone_pythia_qkv_layer( 163 | dst_layer.query_key_value, 164 | src_layer.query_key_value, 165 | dst_layer.num_attention_heads, 166 | src_layer.num_attention_heads, 167 | snr_db=snr_db, 168 | ) 169 | clone_linear_layer(dst_layer.dense, src_layer.dense, snr_db=snr_db) 170 | 171 | 172 | def clone_pythia( 173 | src_network, 174 | embedding_dim_multiplier: int = 1, 175 | up_project_multiplier: int = 1, 176 | **kwargs, 177 | ): 178 | """ 179 | Clones Pythia models. 180 | 181 | For arguments description, refer to hypercloning.cloneModel. 182 | 183 | Returns: 184 | Cloned Pythia model instance. 185 | """ 186 | num_heads_multiplier = kwargs.get("num_heads_multiplier", embedding_dim_multiplier) 187 | snr_db = kwargs.get("snr_db", None) 188 | assert ( 189 | num_heads_multiplier == embedding_dim_multiplier 190 | ), "head_dim expansion is not supported for Pythia. The number of heads will \ 191 | be automatically computed based on ebedding dimension expansion. 
Do not \ 192 | pass 'num_heads_multiplier' to 'clone_pythia'" 193 | config = GPTNeoXConfig(**src_network.config.to_dict()) 194 | 195 | # Set the new config parameters for the destination (expanded) network: 196 | config.hidden_size = embedding_dim_multiplier * config.hidden_size 197 | config.num_attention_heads = num_heads_multiplier * config.num_attention_heads 198 | old_head_dim = ( 199 | src_network.config.hidden_size // src_network.config.num_attention_heads 200 | ) 201 | new_head_dim = config.hidden_size // config.num_attention_heads 202 | config.intermediate_size = up_project_multiplier * config.intermediate_size 203 | 204 | # rename the config according to expansion factors 205 | config = rename_config(config, embedding_dim_multiplier, up_project_multiplier) 206 | 207 | # Create the destination network: 208 | dst_network = GPTNeoXForCausalLM._from_config( 209 | config, 210 | torch_dtype=torch.bfloat16, 211 | attn_implementation=src_network.config._attn_implementation, 212 | ) 213 | 214 | # Clone the embedding layer parameters 215 | dst_network.gpt_neox.embed_in.weight.data = clone_matrix( 216 | dst_network.gpt_neox.embed_in.weight.data.shape, 217 | src_network.gpt_neox.embed_in.weight.data, 218 | normalize=False, 219 | ) 220 | 221 | # Iterate through pairs of layers in source and destination layers. 222 | # Clone source layer components into destination layer components: 223 | for dst_layer, src_layer in zip( 224 | dst_network.gpt_neox.layers, src_network.gpt_neox.layers 225 | ): 226 | clone_layer_norm(dst_layer.input_layernorm, src_layer.input_layernorm) 227 | clone_layer_norm( 228 | dst_layer.post_attention_layernorm, src_layer.post_attention_layernorm 229 | ) 230 | clone_GPTNeoXAttention(dst_layer.attention, src_layer.attention, snr_db=snr_db) 231 | clone_linear_layer( 232 | dst_layer.mlp.dense_h_to_4h, src_layer.mlp.dense_h_to_4h, snr_db=snr_db 233 | ) 234 | clone_linear_layer( 235 | dst_layer.mlp.dense_4h_to_h, src_layer.mlp.dense_4h_to_h, snr_db=snr_db 236 | ) 237 | 238 | # Clone the final layer norm: 239 | clone_layer_norm( 240 | dst_network.gpt_neox.final_layer_norm, src_network.gpt_neox.final_layer_norm 241 | ) 242 | 243 | # Clone the unembedding layer: 244 | clone_linear_layer(dst_network.embed_out, src_network.embed_out) 245 | return dst_network 246 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-hypercloning/7c4e80bfc5329980c50946cd9baab0bb3109cc68/images/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | transformers 4 | ai2-olmo 5 | datasets --------------------------------------------------------------------------------
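# ---------------------------------------------------------------------------
# End-to-end sketch (illustrative only, not part of the original repository):
# clone a small Pythia checkpoint with clone_pythia and compare logits against
# the source model. The checkpoint name and prompt are assumptions; agreement
# is only expected up to bfloat16 rounding when snr_db is left unset.
def _example_check_function_preservation():
    import torch
    from transformers import AutoTokenizer, GPTNeoXForCausalLM

    from hypercloning.pythia_cloning import clone_pythia

    name = "EleutherAI/pythia-70m"  # any Pythia checkpoint should work here
    tokenizer = AutoTokenizer.from_pretrained(name)
    src = GPTNeoXForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)
    dst = clone_pythia(src, embedding_dim_multiplier=2, up_project_multiplier=2)

    ids = tokenizer("Hypercloning expands a pretrained network.", return_tensors="pt").input_ids
    with torch.no_grad():
        src_logits = src(ids).logits
        dst_logits = dst(ids).logits
    # The 2x wider clone should reproduce the source logits closely.
    print("max |logit difference|:", (src_logits - dst_logits).abs().max().item())
# ---------------------------------------------------------------------------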