├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── pull_request_template.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmark.py ├── benchmarks ├── llama1b │ ├── benchmark_cpu_l1.log │ └── benchmark_gpu.log ├── llama3b │ └── benchmark_cpu_l3.log └── sparsity.json ├── configs ├── base.yml ├── deepseek_r1_distill_qwen_skip_1.5b.json ├── gemma3n_skip_causal_e2b.json ├── llama_skip_causal_1b.json ├── llama_skip_causal_1b_predictor_training.json ├── llama_skip_causal_3b.json ├── llama_skip_causal_3b_predictor_training.json ├── mistral_skip_causal_7b.json ├── opt_skip_causal_2.7b.json ├── opt_skip_causal_2.7b_predictor_training.json ├── phi3_skip_causal_3.8b.json ├── qwen2_skip_causal_1.5b.json ├── relullama_skip_causal_7b.json └── relullama_skip_causal_7b_predictor_training.json ├── downstream_eval.py ├── generate_dataset.py ├── measure_gt_sparsity.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── sparse_transformers ├── CMakeLists.txt ├── __init__.py └── csrc │ ├── .clang-format │ ├── approx_topk.h │ ├── sparse_mlp_cuda.cu │ ├── sparse_mlp_op.cpp │ └── weight_cache.h ├── src ├── __init__.py ├── activation_capture.py ├── configuration_skip.py ├── modeling_skip.py ├── models │ ├── __init__.py │ ├── gemma3n │ │ ├── __init__.py │ │ ├── configuration_gemma_skip.py │ │ └── modelling_gemma_skip.py │ ├── llama │ │ ├── __init__.py │ │ ├── configuration_llama_skip.py │ │ └── modelling_llama_skip.py │ ├── mistral │ │ ├── __init__.py │ │ ├── configuration_mistral_skip.py │ │ └── modelling_mistral_skip.py │ ├── opt │ │ ├── __init__.py │ │ ├── activation_capture_opt.py │ │ ├── configuration_opt_skip.py │ │ └── modelling_opt_skip.py │ ├── phi3 │ │ ├── __init__.py │ │ ├── configuration_phi_skip.py │ │ └── modelling_phi_skip.py │ └── qwen2 │ │ ├── __init__.py │ │ ├── configuration_qwen_skip.py │ │ └── modelling_qwen_skip.py ├── trainer.py └── utilities │ ├── __init__.py │ ├── cuda_utils.py │ ├── logger.py │ ├── random.py │ ├── registry.py │ ├── saver.py │ └── sys_utils.py ├── train.py └── train_parallel.sh /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Report a bug to help us improve sparse transformers 4 | title: '[BUG] ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | **Describe the Bug** 10 | Include description of what the bug is and add steps to reproduce this behavior. 11 | 12 | **Expected Behavior** 13 | A clear and concise description of what you expected to happen. 14 | 15 | **Additional Information** 16 | Add any other context about the problem here. 17 | 18 | **Possible Solution** 19 | If you have suggestions on how to fix the issue, please describe it here. 
20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea to improve sparse transformers 4 | title: '[FEATURE] ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | **Describe the feature request** 10 | 11 | **Describe the solution you'd like** 12 | 13 | **Describe alternatives you've considered** 14 | 15 | **Additional context** 16 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Please provide a description of this PR. 3 | 4 | Fixes # (issue) 5 | 6 | ## Checklist: 7 | - [ ] I have added tests that prove my fix is effective or that my feature works 8 | - [ ] Has user-facing changes. This may include API or behavior changes, performance improvements, etc. 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | llama.cpp/ 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | .cursorrules 24 | trained_predictors/ 25 | wandb/ 26 | data/ 27 | logs/ 28 | # CUDA 29 | *.i 30 | *.ii 31 | *.gpu 32 | *.ptx 33 | *.cubin 34 | *.fatbin 35 | *.o 36 | *.obj 37 | *.pkl 38 | *.png 39 | # IDE specific files 40 | .idea/ 41 | .vscode/ 42 | *.swp 43 | *.swo 44 | .project 45 | .pydevproject 46 | .settings/ 47 | .vs/ 48 | 49 | # Environment 50 | .env 51 | .venv 52 | env/ 53 | venv/ 54 | ENV/ 55 | env.bak/ 56 | venv.bak/ 57 | 58 | # Distribution / packaging 59 | .Python 60 | build/ 61 | develop-eggs/ 62 | dist/ 63 | downloads/ 64 | eggs/ 65 | .eggs/ 66 | lib/ 67 | lib64/ 68 | parts/ 69 | sdist/ 70 | var/ 71 | wheels/ 72 | *.egg-info/ 73 | .installed.cfg 74 | *.egg 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # Testing 80 | htmlcov/ 81 | .tox/ 82 | .coverage 83 | .coverage.* 84 | .cache 85 | nosetests.xml 86 | coverage.xml 87 | *.cover 88 | .hypothesis/ 89 | 90 | # Logs and databases 91 | *.sqlite 92 | *.db 93 | data/ 94 | 95 | # OS generated files 96 | .DS_Store 97 | .DS_Store? 98 | ._* 99 | .Spotlight-V100 100 | .Trashes 101 | ehthumbs.db 102 | Thumbs.db 103 | build.log 104 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community.
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official email address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at . 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. 
Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 125 | [https://www.contributor-covenant.org/translations][translations]. 126 | 127 | [homepage]: https://www.contributor-covenant.org 128 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 129 | [Mozilla CoC]: https://github.com/mozilla/diversity 130 | [FAQ]: https://www.contributor-covenant.org/faq 131 | [translations]: https://www.contributor-covenant.org/translations 132 | 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Sparse Transformers 2 | 3 | Thank you for your interest in contributing to Sparse Transformers! This document provides guidelines and instructions for contributing to this project. 4 | 5 | ## Development Workflow 6 | 7 | We follow a fork and pull request workflow for all contributions. Here's how it works: 8 | 9 | 1. Fork the repository 10 | 2. Create a new branch for your feature/fix 11 | 3. Make your changes 12 | 4. Submit a pull request 13 | 14 | ### Detailed Steps 15 | 16 | 1. **Fork the Repository** 17 | - Click the "Fork" button on the top right of the repository page 18 | - Clone your fork locally: 19 | ```bash 20 | git clone https://github.com/YOUR-USERNAME/sparse_transformers.git 21 | cd sparse_transformers 22 | ``` 23 | 24 | 2. **Create a Branch** 25 | - Create a new branch for your changes: 26 | ```bash 27 | git checkout -b feature/your-feature-name 28 | ``` 29 | 30 | 3. **Make Changes** 31 | - Make your changes and commit them 32 | - Write clear commit messages 33 | - Test your changes thoroughly 34 | 35 | 4. **Submit a Pull Request** 36 | - Push your branch to your fork 37 | - Create a pull request against the main repository 38 | - Fill out the pull request template completely 39 | 40 | ## Developer Certificate of Origin (DCO) 41 | 42 | This project requires all contributors to sign off on their commits. This is done through the Developer Certificate of Origin (DCO).
The DCO is a lightweight way for contributors to certify that they wrote or otherwise have the right to submit the code they are contributing to the project. 43 | 44 | ### How to Sign Off 45 | 46 | Each commit message must include a Signed-off-by line with your name and email address. You can add this automatically using the `-s` flag when committing: 47 | 48 | ```bash 49 | git commit -s -m "Your commit message" 50 | ``` 51 | 52 | The sign-off line should look like this: 53 | ``` 54 | Signed-off-by: Your Name 55 | ``` 56 | 57 | For more information about the DCO, please visit [DCO App Documentation](https://github.com/dcoapp/app#how-it-works). 58 | 59 | ## Pull Request Requirements 60 | 61 | 1. **Code Quality** 62 | - Follow the existing code style 63 | - Write clear, maintainable code 64 | - Include appropriate tests 65 | - Update documentation as needed 66 | 67 | 2. **Commit Messages** 68 | - Use clear, descriptive commit messages 69 | - Include the DCO sign-off in each commit 70 | - Reference any related issues 71 | 72 | 3. **Pull Request Description** 73 | - Clearly describe the changes 74 | - Reference any related issues 75 | - Ensure all tests and checks are passing 76 | 77 | ## Getting Help 78 | 79 | If you need help or have questions: 80 | - Open an issue 81 | - Join our community discussions 82 | - Reach out to the maintainers 83 | 84 | Thank you for contributing to Sparse Transformers! 85 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Join us on Discord](https://img.shields.io/badge/Join%20us-Discord-5865F2?logo=discord&logoColor=white)](https://discord.gg/y8WkMncstk) 2 | 3 | # Fused Sparse C++ Kernels for Transformers 4 | 5 | ## Overview 6 | 7 | The project implements sparse multiplication and fuses up/down projections in the MLP layers through low rank weight activations. 
8 | Work is based on [Deja Vu](https://arxiv.org/abs/2310.17157) and Apple's [LLM in a Flash](https://arxiv.org/abs/2312.11514). 9 | 10 | ### Benefits 11 | - **1.6-1.8x overall gain in TTFT and TPS** (4-5x gain in MLP Inference) 12 | - **26.4%** reduction in memory usage 13 | - **6.7×** faster index selection and replacement for weight caching 14 | 15 | 16 | ``` 17 | ┌─────────────────────────────────────────────────────────────────┐ 18 | │ Sparse LLM Inference Pipeline │ 19 | ├─────────────────────────────────────────────────────────────────┤ 20 | │ Sparsity Selection │ 21 | │ ├─ Hidden States → LoRA Projection (Importance Scoring) │ 22 | │ ├─ Binary Mask Generation: (scores > threshold) │ 23 | │ └─ Mask Normalization: Union across batch dimension │ 24 | ├─────────────────────────────────────────────────────────────────┤ 25 | │ Differential Weight Caching │ 26 | │ ├─ Mask Change Detection: XOR with previous mask │ 27 | │ ├─ Paired Replacement: Direct substitution algorithm │ 28 | │ └─ Zero-Copy Tensor Views: torch::from_blob references │ 29 | ├─────────────────────────────────────────────────────────────────┤ 30 | │ Sparse Computation │ 31 | │ ├─ Concatenated Gate+Up Projection (Fused Operation) │ 32 | │ ├─ Element-wise Activation: σ(gate) ⊙ up │ 33 | │ └─ Sparse Down Projection: Only active intermediate dims │ 34 | └─────────────────────────────────────────────────────────────────┘ 35 | ``` 36 | 37 | **Keywords:** Large Language Models, Sparse Inference, Differential Weight Caching 38 | 39 | ## Performance Benchmarks 40 | State of Implementation: 41 | - [x] Torch CPU kernels for fp16, fp32 42 | - [x] Differential weight caching and selection for dynamic sparsity 43 | - [ ] CUDA kernels for Sparse Inferencing 44 | - [ ] CPU kernels for int8, int32, int64 45 | 46 | ### CPU Performance 47 | ``` 48 | Sparse LLaMA 3.2 3B vs LLaMA 3.2 3B (on HuggingFace Implementation): 49 | 50 | - Time to First Token (TTFT): 1.51× faster (1.209s → 0.803s) 51 | - Output Generation Speed: 1.79× faster (0.7 → 1.2 tokens/sec) 52 | - Total Throughput: 1.78× faster (0.7 → 1.3 tokens/sec) 53 | - Memory Usage: 26.4% reduction (13.25GB → 9.75GB) 54 | ``` 55 | 56 | ### GPU Performance 57 | 58 | ``` 59 | Sparse LLaMA 3.2 3B vs Standard LLaMA 3.2 3B CUDA Results (on HuggingFace Implementation): 60 | 61 | - Average time (Sparse): 0.021s 62 | - Average time (Standard): 0.018s 63 | - CUDA Speedups: 0.86x (WIP) 64 | ``` 65 | 66 | ## Usage 67 | 68 | ### Quick Benchmark 69 | 70 | ```bash 71 | # Run comprehensive benchmark 72 | 73 | python benchmark.py \ 74 | --device cpu \ # Device: 'cpu' or 'cuda' 75 | --config configs/llama_skip_causal_3b.json \ # Model configuration 76 | --num_runs 50 \ # Number of benchmark runs 77 | --verbose True # Detailed timing output 78 | 79 | # Expected output: 80 | # ⚡ TTFT Speedup: 1.51x 81 | # 🚀 Output TPS Speedup: 1.79x 82 | # 📊 Total Throughput Speedup: 1.78x 83 | ``` 84 | 85 | ## Implementation Details 86 | 87 | ### Paired Replacement with Differential Caching 88 | _sparse_transformers/csrc/weight_cache.h_ 89 | 90 | The weight cache is a class that manages the active weights for the sparse MLP. It differentially updates the MLP tensor memory pool for the next token based on the predicted sparsity mask. 
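To make the differential update concrete, here is a minimal Python sketch of the paired-replacement idea: XOR the new mask with the previous one to find rows whose state changed, then write each newly activated row directly into the slot freed by a newly deactivated row. This is illustrative only — the real logic lives in the C++ `WeightCache` shown below — and the helper names (`row_to_slot`, `weight_pool`, `full_weight`) are invented for the example, which also assumes equally many rows enter and leave the active set.

```python
import torch

def paired_replacement(mask: torch.Tensor,        # bool[num_rows], new active set
                       prev_mask: torch.Tensor,   # bool[num_rows], previous active set
                       row_to_slot: dict,         # active row index -> slot in the pool
                       weight_pool: torch.Tensor, # [capacity, hidden], dense weight cache
                       full_weight: torch.Tensor) -> None:
    """Update the dense pool in place, touching only rows whose state flipped."""
    changed = mask ^ prev_mask                                       # mask change detection
    added = torch.nonzero(changed & mask).flatten().tolist()         # rows to bring in
    removed = torch.nonzero(changed & prev_mask).flatten().tolist()  # rows to evict
    for new_row, old_row in zip(added, removed):
        slot = row_to_slot.pop(old_row)                # reuse the slot freed by the evicted row
        weight_pool[slot].copy_(full_weight[new_row])  # one contiguous row copy
        row_to_slot[new_row] = slot
    prev_mask.copy_(mask)
```

Because only the changed rows are copied, instead of rebuilding the whole active matrix with `index_select`, the update cost scales with the mask delta rather than with the number of active neurons — which is where the reported 6.7× speedup in cache updates comes from.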
91 | 92 | ```cpp 93 | class WeightCache { 94 | // Paired replacement algorithm for differential updates 95 | void update_active_weights(const torch::Tensor &mask) 96 | 97 | }; 98 | ``` 99 | 100 | **Performance Impact:** 101 | - **6.7× faster cache updates**: 29.89ms (naive `index_select`) → 4.46ms (paired replacement) 102 | - **Better cache locality**: Row major for Up Projection and Column major for Down Projection Matrices 103 | - **Contiguous Memory Access**: Single memcpy for cache updates 104 | 105 | ### Sparse MLP Inference 106 | _sparse_transformers/csrc/sparse_mlp_op.cpp_ 107 | 108 | ```python 109 | sparse_mlp_forward( 110 | x.detach(), 111 | self.weight_cache.get_concat_weight(), 112 | self.weight_cache.get_active_down_weight(), 113 | self.down_proj_buffer, 114 | self.combined_proj_buffer, 115 | "silu" 116 | ) 117 | ``` 118 | 119 | **Performance Impact:** 120 | - **5× faster CPU MLP inference**: 30.1ms → 6.02ms 121 | - OpenMP parallelization with `torch::at::parallel_for` 122 | - Bounded memory usage with weight cache memory pool 123 | 124 | ## Project Structure 125 | 126 | ``` 127 | ├── sparse_transformers/ # C++ extension module 128 | │ ├── csrc/ 129 | │ │ ├── sparse_mlp_op.cpp # Main CPU/CUDA dispatcher 130 | │ │ ├── sparse_mlp_cuda.cu # CUDA kernels 131 | │ │ └── weight_cache.h # Paired replacement caching 132 | │ ├── __init__.py # Python bindings 133 | │ └── CMakeLists.txt # Build configuration 134 | ├── src/models/llama/ 135 | │ ├── modelling_llama_skip.py # Statistical sparsity model 136 | │ └── configuration_llama_skip.py # Model configuration 137 | ├── tools/ 138 | │ └── component_timing.py # Performance profiling 139 | └── run_benchmark.py # End-to-end benchmarks 140 | ``` 141 | 142 | ## Installation 143 | 144 | ### Build C++ Extensions 145 | ```bash 146 | # Clone repository 147 | git clone https://github.com/nimbleedge/sparse_transformers.git 148 | cd sparse_transformers 149 | ``` 150 | 151 | Set up conda environment and install dependencies 152 | ```bash 153 | conda create -n sparse_transformers python=3.10 154 | conda activate sparse_transformers 155 | ``` 156 | 157 | Install torch dependencies from [requirements.txt](requirements.txt#L2) 158 | 159 | ```bash 160 | # Install in editable mode (builds C++ extensions automatically) 161 | pip install -r requirements.txt 162 | pip install -e . # Auto-detect (prefer GPU if available) 163 | pip install -e . --build-option=cpu # Force CPU-only build 164 | pip install -e . --build-option=gpu # Force GPU build (fallback to CPU if not available) 165 | 166 | # Alternative: Direct setup.py commands 167 | python setup.py develop # Auto-detect (prefer GPU if available) 168 | python setup.py develop cpu # Force CPU-only build 169 | python setup.py develop gpu # Force GPU build (fallback to CPU if not available) 170 | 171 | # Verify installation 172 | python -c "import sparse_transformers; print('✅ Installation successful')" 173 | ``` 174 | 175 | ## Community engagement 176 | We welcome any feedback or suggestions - please join our 177 | [Discord](https://discord.gg/y8WkMncstk) to engage with the community. 178 | 179 | ## Contributing 180 | We welcome contributions from the community! Areas of particular interest are: 181 | - **Additional models**: Extend beyond LLaMA to other architectures 182 | - **Quantization**: Combine with INT8/FP16 optimizations 183 | - **Attention Kernels**: Implement Sparse Attention Kernels 184 | 185 | Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started. 
186 | 187 | ## License 188 | 189 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. 190 | 191 | -------------------------------------------------------------------------------- /benchmarks/llama1b/benchmark_gpu.log: -------------------------------------------------------------------------------- 1 | Device set to use cuda 2 | Device set to use cuda 3 | Configuring for 8 CPU threads 4 | 5 | System Configuration: 6 | -------------------------------------------------- 7 | OS: Linux 5.15.0-1079-azure 8 | CPU: x86_64 9 | Physical cores: 8 10 | Total cores: 8 11 | Max CPU frequency: 0MHz 12 | Current CPU frequency: 2544MHz 13 | RAM: Total=54.92GB, Available=51.71GB (5.8% used) 14 | 15 | GPU Configuration: 16 | -------------------------------------------------- 17 | 18 | GPU 0: Tesla T4 19 | Compute capability: 7.5 20 | Total memory: 15.56GB 21 | Free memory: 15.37GB 22 | Multi processors: 40 23 | 24 | PyTorch version: 2.5.1 25 | CUDA version: 12.4 26 | -------------------------------------------------- 27 | Number of available GPUs: 1 28 | Using devices: cuda, cuda, cuda 29 | 30 | Running CUDA inference benchmarks... 31 | -------------------------------------------------- 32 | Warming up models... 33 | 34 | Model type: 35 | Model device: cuda 36 | Model path: meta-llama/Llama-3.2-1B-Instruct 37 | 38 | Model type: 39 | Model device: cuda 40 | Model path: meta-llama/Llama-3.2-1B-Instruct 41 | 42 | Model type: 43 | Model device: cuda 44 | Model path: meta-llama/Llama-3.2-1B-Instruct 45 | 46 | Model type: 47 | Model device: cuda 48 | Model path: meta-llama/Llama-3.2-1B-Instruct 49 | 50 | SkipLLaMA Scripted CUDA Results: 51 | Average time: 0.021s 52 | Min time: 0.020s 53 | Max time: 0.021s 54 | Individual times: ['0.021s', '0.021s', '0.021s', '0.021s', '0.021s', '0.020s', '0.021s', '0.021s', '0.020s', '0.021s'] 55 | 56 | Standard LLaMA CUDA Results: 57 | Average time: 0.018s 58 | Min time: 0.018s 59 | Max time: 0.018s 60 | Individual times: ['0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s'] 61 | 62 | CUDA Speedups: 63 | Scripted vs Standard: 0.86x 64 | -------------------------------------------------------------------------------- /benchmarks/llama3b/benchmark_cpu_l3.log: -------------------------------------------------------------------------------- 1 | Configuration: configs/llama_skip_causal_3b.json 2 | 3 | Configuring for 8 CPU threads 4 | 5 | System Configuration: 6 | -------------------------------------------------- 7 | OS: Linux 5.15.0-1089-azure 8 | CPU: x86_64 9 | Physical cores: 8 10 | Total cores: 8 11 | Max CPU frequency: 0MHz 12 | Current CPU frequency: 2546MHz 13 | RAM: Total=54.92GB, Available=45.06GB (18.0% used) 14 | 15 | PyTorch version: 2.5.1 16 | CUDA version: 12.4 17 | -------------------------------------------------- 18 | Using devices: cpu, cpu, cpu 19 | 20 | Loading checkpoint shards: 0%| | 0/2 [00:00", 58 | "torch_dtype": "float16", 59 | "transformers_version": "4.21.0.dev0", 60 | "use_cache": true, 61 | "vocab_size": 50272, 62 | "word_embed_proj_dim": 2560 63 | } -------------------------------------------------------------------------------- /configs/opt_skip_causal_2.7b_predictor_training.json: -------------------------------------------------------------------------------- 1 | { 2 | "sparsities": [ 3 | 0.9851281009614468, 4 | 0.9946803697384894, 5 | 0.9953532270155847, 6 | 0.9962933831848204, 7 | 0.9901570179499686, 8 | 0.9955691522918642, 9 | 0.9949185764417052, 10 | 
0.9948671241290867, 11 | 0.9951038942672312, 12 | 0.9942043093033135, 13 | 0.9923509689979255, 14 | 0.99130690516904, 15 | 0.9868006301112473, 16 | 0.9852745439857244, 17 | 0.9817174589261413, 18 | 0.977879146579653, 19 | 0.9693511691875756, 20 | 0.959740437567234, 21 | 0.9505377076566219, 22 | 0.9496744400821626, 23 | 0.945977492723614, 24 | 0.9431825019419193, 25 | 0.9414051696658134, 26 | 0.9482622602954507, 27 | 0.9461887339130044, 28 | 0.9471860215999186, 29 | 0.944672588724643, 30 | 0.9415744030848145, 31 | 0.9331777216866612, 32 | 0.9259660355746746, 33 | 0.9233388286083937, 34 | 0.9436734234914184 35 | ], 36 | "_name_or_path": "facebook/opt-2.7b", 37 | "training": true, 38 | "use_optimized_weight_cache": true, 39 | "return_dict": true, 40 | "_remove_final_layer_norm": false, 41 | "activation_dropout": 0.0, 42 | "activation_function": "relu", 43 | "architectures": [ 44 | "OPTSkipConnectionForCausalLM" 45 | ], 46 | "attention_dropout": 0.0, 47 | "bos_token_id": 2, 48 | "do_layer_norm_before": true, 49 | "dropout": 0.1, 50 | "eos_token_id": 2, 51 | "intermediate_size": 10240, 52 | "hidden_size": 2560, 53 | "init_std": 0.02, 54 | "layerdrop": 0.0, 55 | "max_position_embeddings": 2048, 56 | "model_type": "opt-skip", 57 | "num_attention_heads": 32, 58 | "num_hidden_layers": 32, 59 | "pad_token_id": 1, 60 | "prefix": "", 61 | "torch_dtype": "float16", 62 | "transformers_version": "4.21.0.dev0", 63 | "use_cache": true, 64 | "vocab_size": 50272, 65 | "word_embed_proj_dim": 2560 66 | } -------------------------------------------------------------------------------- /configs/phi3_skip_causal_3.8b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "microsoft/Phi-4-mini-instruct", 3 | "sparsity": 0.3, 4 | "architectures": [ 5 | "Phi3SkipConnectionForCausalLM" 6 | ], 7 | "attention_bias": false, 8 | "attention_dropout": 0.0, 9 | "auto_map": { 10 | "AutoTokenizer": "Xenova/gpt-4o" 11 | }, 12 | "bos_token_id": 199999, 13 | "embd_pdrop": 0.0, 14 | "eos_token_id": 199999, 15 | "full_attn_mod": 1, 16 | "hidden_act": "silu", 17 | "hidden_size": 3072, 18 | "initializer_range": 0.02, 19 | "intermediate_size": 8192, 20 | "interpolate_factor": 1, 21 | "lm_head_bias": false, 22 | "max_position_embeddings": 131072, 23 | "mlp_bias": false, 24 | "model_type": "phi3-skip", 25 | "num_attention_heads": 24, 26 | "num_hidden_layers": 32, 27 | "num_key_value_heads": 8, 28 | "original_max_position_embeddings": 4096, 29 | "pad_token_id": 199999, 30 | "partial_rotary_factor": 0.75, 31 | "resid_pdrop": 0.0, 32 | "rms_norm_eps": 1e-05, 33 | "rope_scaling": { 34 | "long_factor": [ 35 | 1, 36 | 1.118320672, 37 | 1.250641126, 38 | 1.398617824, 39 | 1.564103225, 40 | 1.74916897, 41 | 1.956131817, 42 | 2.187582649, 43 | 2.446418898, 44 | 2.735880826, 45 | 3.059592084, 46 | 3.421605075, 47 | 3.826451687, 48 | 4.279200023, 49 | 4.785517845, 50 | 5.351743533, 51 | 5.984965424, 52 | 6.693110555, 53 | 7.485043894, 54 | 8.370679318, 55 | 9.36110372, 56 | 10.4687158, 57 | 11.70738129, 58 | 13.09260651, 59 | 14.64173252, 60 | 16.37415215, 61 | 18.31155283, 62 | 20.47818807, 63 | 22.90118105, 64 | 25.61086418, 65 | 28.64115884, 66 | 32.03, 67 | 32.1, 68 | 32.13, 69 | 32.23, 70 | 32.6, 71 | 32.61, 72 | 32.64, 73 | 32.66, 74 | 32.7, 75 | 32.71, 76 | 32.93, 77 | 32.97, 78 | 33.28, 79 | 33.49, 80 | 33.5, 81 | 44.16, 82 | 47.77 83 | ], 84 | "short_factor": [ 85 | 1.0, 86 | 1.0, 87 | 1.0, 88 | 1.0, 89 | 1.0, 90 | 1.0, 91 | 1.0, 92 | 1.0, 93 | 1.0, 94 | 1.0, 95 | 1.0, 96 | 1.0, 
97 | 1.0, 98 | 1.0, 99 | 1.0, 100 | 1.0, 101 | 1.0, 102 | 1.0, 103 | 1.0, 104 | 1.0, 105 | 1.0, 106 | 1.0, 107 | 1.0, 108 | 1.0, 109 | 1.0, 110 | 1.0, 111 | 1.0, 112 | 1.0, 113 | 1.0, 114 | 1.0, 115 | 1.0, 116 | 1.0, 117 | 1.0, 118 | 1.0, 119 | 1.0, 120 | 1.0, 121 | 1.0, 122 | 1.0, 123 | 1.0, 124 | 1.0, 125 | 1.0, 126 | 1.0, 127 | 1.0, 128 | 1.0, 129 | 1.0, 130 | 1.0, 131 | 1.0, 132 | 1.0 133 | ], 134 | "type": "longrope" 135 | }, 136 | "rope_theta": 10000.0, 137 | "sliding_window": 262144, 138 | "tie_word_embeddings": true, 139 | "torch_dtype": "bfloat16", 140 | "transformers_version": "4.45.0", 141 | "use_cache": true, 142 | "vocab_size": 200064 143 | } 144 | 145 | -------------------------------------------------------------------------------- /configs/qwen2_skip_causal_1.5b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "Qwen/Qwen2-1.5B", 3 | "sparsity_method": "naive", 4 | "sparsities": [ 5 | 0.7590979598462582, 6 | 0.7837754879146814, 7 | 0.7750004627741873, 8 | 0.770095819607377, 9 | 0.7902259933762252, 10 | 0.8609298202209175, 11 | 0.8853354640305042, 12 | 0.8391136317513883, 13 | 0.8428834634833038, 14 | 0.8334567951969802, 15 | 0.8006975213065743, 16 | 0.8045121841132641, 17 | 0.8163408637046814, 18 | 0.7881555473431945, 19 | 0.7771737249568105, 20 | 0.7892530923709273, 21 | 0.7808592799119651, 22 | 0.8027198943309486, 23 | 0.8147787847556174, 24 | 0.8215602654963732, 25 | 0.8315094322897494, 26 | 0.883263947442174, 27 | 0.9112895792350173, 28 | 0.8935728948563337, 29 | 0.8899188376963139, 30 | 0.8447432252578437, 31 | 0.7769220205955207, 32 | 0.8743839487433434 33 | ], 34 | "architectures": [ 35 | "Qwen2SkipForCausalLM" 36 | ], 37 | "attention_dropout": 0.0, 38 | "bos_token_id": 151643, 39 | "eos_token_id": 151643, 40 | "hidden_act": "silu", 41 | "hidden_size": 1536, 42 | "initializer_range": 0.02, 43 | "intermediate_size": 8960, 44 | "max_position_embeddings": 131072, 45 | "max_window_layers": 28, 46 | "model_type": "qwen2-skip", 47 | "num_attention_heads": 12, 48 | "num_hidden_layers": 28, 49 | "num_key_value_heads": 2, 50 | "rms_norm_eps": 1e-06, 51 | "rope_theta": 1000000.0, 52 | "sliding_window": 131072, 53 | "tie_word_embeddings": true, 54 | "torch_dtype": "bfloat16", 55 | "transformers_version": "4.40.1", 56 | "use_cache": true, 57 | "use_sliding_window": false, 58 | "vocab_size": 151936 59 | } 60 | 61 | -------------------------------------------------------------------------------- /configs/relullama_skip_causal_7b.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "SparseLLM/ReluLLaMA-7B", 3 | "sparsity_method": "naive", 4 | "sparsities": [ 5 | 0.743192738853395, 6 | 0.7336088079027832, 7 | 0.6890966654755175, 8 | 0.7145767104811966, 9 | 0.73568778578192, 10 | 0.7505242507904768, 11 | 0.7502159993164241, 12 | 0.747797972522676, 13 | 0.7135628159157932, 14 | 0.7085652560926974, 15 | 0.6838056156411767, 16 | 0.6900686351582408, 17 | 0.6819221628829837, 18 | 0.6771378242410719, 19 | 0.6827241298742592, 20 | 0.6764037436805665, 21 | 0.6925274380482733, 22 | 0.7076996429823339, 23 | 0.7007281989790499, 24 | 0.6772957886569202, 25 | 0.6694492469541728, 26 | 0.6437577940523624, 27 | 0.6286926441825926, 28 | 0.611934173386544, 29 | 0.604402104858309, 30 | 0.607235555537045, 31 | 0.6056556575931609, 32 | 0.6122667123563588, 33 | 0.6011746157892048, 34 | 0.5901505807414651, 35 | 0.592667915392667, 36 | 0.5540782236494124 37 | ], 38 | "architectures": 
["LlamaSkipConnectionForCausalLM"], 39 | "bos_token_id": 1, 40 | "eos_token_id": 2, 41 | "hidden_act": "relu", 42 | "hidden_size": 4096, 43 | "initializer_range": 0.02, 44 | "intermediate_size": 11008, 45 | "max_length": 4096, 46 | "max_position_embeddings": 2048, 47 | "model_type": "llama-skip", 48 | "num_attention_heads": 32, 49 | "num_hidden_layers": 32, 50 | "num_key_value_heads": 32, 51 | "pad_token_id": 0, 52 | "pretraining_tp": 1, 53 | "rms_norm_eps": 1e-05, 54 | "rope_scaling": null, 55 | "tie_word_embeddings": false, 56 | "torch_dtype": "float32", 57 | "transformers_version": "4.31.0", 58 | "use_cache": true, 59 | "vocab_size": 32000 60 | } -------------------------------------------------------------------------------- /configs/relullama_skip_causal_7b_predictor_training.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "SparseLLM/ReluLLaMA-7B", 3 | "training": true, 4 | "predictor_loss_type": "bce", 5 | "predictor_temperature": 1.0, 6 | "predictor_loss_alpha": 1.0, 7 | "predictor_loss_weight": 0.1, 8 | "predictor_confidence_penalty": 0.15, 9 | "predictor_focal_gamma": 2.0, 10 | "use_optimized_weight_cache": true, 11 | "return_dict": true, 12 | "sparsity": 0.3, 13 | "architectures": ["LlamaSkipConnectionForCausalLM"], 14 | "bos_token_id": 1, 15 | "eos_token_id": 2, 16 | "hidden_act": "relu", 17 | "hidden_size": 4096, 18 | "initializer_range": 0.02, 19 | "intermediate_size": 11008, 20 | "max_length": 4096, 21 | "max_position_embeddings": 2048, 22 | "model_type": "llama-skip", 23 | "num_attention_heads": 32, 24 | "num_hidden_layers": 32, 25 | "num_key_value_heads": 32, 26 | "pad_token_id": 0, 27 | "pretraining_tp": 1, 28 | "rms_norm_eps": 1e-05, 29 | "rope_scaling": null, 30 | "tie_word_embeddings": false, 31 | "torch_dtype": "float32", 32 | "transformers_version": "4.31.0", 33 | "use_cache": true, 34 | "vocab_size": 32000 35 | } -------------------------------------------------------------------------------- /downstream_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | 5 | import torch 6 | 7 | from transformers import AutoConfig, AutoModelForCausalLM 8 | from lm_eval import simple_evaluate 9 | from lm_eval.utils import make_table 10 | from lm_eval.models.huggingface import HFLM 11 | 12 | import src.models 13 | 14 | # Setup logging 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Evaluate trained model on common LM datasets using LM Eval Harness.") 20 | parser.add_argument("--model_type", type=str, choices=["hf", "sparse"], default="hf") 21 | parser.add_argument("--model_name_or_config", type=str, required=True, 22 | help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)") 23 | parser.add_argument("--tasks", nargs='+', default=["hellaswag"], 24 | help="Tasks on which to evaluate") 25 | parser.add_argument("--batch_size", type=int, default=4, 26 | help="Batch size for processing") 27 | parser.add_argument("--device", type=str, default="auto", 28 | help="Device to use (auto, cpu, cuda)") 29 | parser.add_argument("--sp_dir", type=str, default="", 30 | help="Path to trained predictor dir for sparse model.") 31 | parser.add_argument("--lora_size", type=float, default=4.0, 32 | help="Size of lora predictors to use as percentage of total hidden size") 33 | parser.add_argument("--sp_layers", 
default="all", nargs='+', 34 | help="Which layers to use sparse predictors for") 35 | parser.add_argument("--sparsity_method", default="naive", choices=["naive", "topk", "statistical_topk"], 36 | help="Which method to use to determine active indices") 37 | parser.add_argument("--disable_weight_cache", action="store_true", 38 | help="Disable weight cache and compute sparse mlp manually") 39 | return parser.parse_args() 40 | 41 | 42 | def main(): 43 | args = parse_args() 44 | 45 | # Setup device 46 | if args.device == "auto": 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | else: 49 | device = torch.device(args.device) 50 | 51 | logger.info(f"Using device: {device}") 52 | 53 | # Load pretrained model 54 | logging.info("Loading pretrained model for evaluation...") 55 | 56 | if args.model_type == "hf": 57 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_config) 58 | if args.model_type == "sparse": 59 | config = AutoConfig.from_pretrained(args.model_name_or_config) 60 | config.sp_layers = "all" if "all" in args.sp_layers else [int(x) for x in args.sp_layers] 61 | config.lora_size = args.lora_size / 100.0 62 | config.sparsity_method = args.sparsity_method 63 | if args.disable_weight_cache: 64 | config.use_weight_cache = False 65 | model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config) 66 | for layer_idx in model.get_decoder().sp_layers: 67 | layer = model.get_decoder().layers[layer_idx] 68 | layer_path = os.path.join(args.sp_dir, f"final_predictor_layer_{layer_idx}_lora_{args.lora_size}pct.pt") 69 | if not os.path.exists(layer_path): 70 | logger.error(f"Pretrained weights for sparse predictor at layer {layer_idx} do not exist.") 71 | return 72 | pretrained_dict = torch.load(layer_path) 73 | layer.mlp_lora_proj.load_state_dict(pretrained_dict) 74 | model.tie_weights() 75 | model.to(device) 76 | model.reset_cache() 77 | 78 | wrapped_model = HFLM( 79 | pretrained=model, 80 | backend="causal", 81 | batch_size=args.batch_size, 82 | device=device 83 | ) 84 | 85 | logging.info("Beginning evaluation...") 86 | results = simple_evaluate( 87 | wrapped_model, 88 | tasks=args.tasks, 89 | batch_size=args.batch_size, 90 | device=device 91 | ) 92 | 93 | if results is not None: 94 | print(make_table(results)) 95 | if "groups" in results: 96 | print(make_table(results, "groups")) 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /measure_gt_sparsity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Calculate ground-truth sparsities for various base models on a given dataset. 4 | 5 | This script takes a list of HuggingFace-compatible models and runs each model 6 | on a number of samples from a given dataset. Activation statistics are captured 7 | from the models' forward passes, and used to determine the average ground-truth 8 | sparsity of each layer for each model. 9 | 10 | This data can then be plotted or saved in a json file to be used as thresholds 11 | for the topk or statistical-topk sparsity methods using trained predictors. 
12 | 13 | 14 | Usage examples: 15 | # Capture ground truth sparsity values for a particular model or models 16 | python measure_gt_sparsity.py \ 17 | --models meta-llama/Llama-3.2-3B-Instruct \ 18 | --num_samples 2048 \ 19 | --max_length 512 \ 20 | --output_dir sparsities \ 21 | --device cuda 22 | 23 | # Generate a plot of ground truth sparsity values by layer and model 24 | python measure_gt_sparsity.py \ 25 | --models meta-llama/Llama-3.2-3B-Instruct Qwen/Qwen2-1.5B google/gemma-3n-E2B \ 26 | --num_samples 2048 \ 27 | --max_length 512 \ 28 | --output_dir sparsities \ 29 | --device cuda \ 30 | --make_plots 31 | """ 32 | 33 | 34 | 35 | import argparse 36 | from collections import defaultdict 37 | import json 38 | import logging 39 | import os 40 | from typing import Dict 41 | 42 | from datasets import load_dataset 43 | import torch 44 | from torch.utils.data import DataLoader, Dataset 45 | from tqdm import tqdm 46 | from transformers import AutoModelForCausalLM, AutoTokenizer 47 | from transformers.trainer_utils import set_seed 48 | 49 | import matplotlib.pyplot as plt 50 | from src.activation_capture import Hook, capture_model 51 | 52 | # Setup logging 53 | logging.basicConfig(level=logging.INFO) 54 | logger = logging.getLogger(__name__) 55 | 56 | 57 | class ContextualSparsityAnalyzer: 58 | """Analyzer for measuring contextual sparsity patterns in LLaMA models.""" 59 | 60 | def __init__(self, model, tokenizer, device): 61 | self.model = model 62 | self.tokenizer = tokenizer 63 | self.device = device 64 | 65 | model.activation_capture = capture_model(model) 66 | model.activation_capture.register_hooks(hooks=[Hook.ACT]) 67 | self.num_layers = len(self.model.activation_capture.get_layers()) 68 | 69 | self.reset_buffers() 70 | 71 | def reset_buffers(self): 72 | self.mlp_sparsity = defaultdict(list) 73 | self.num_seqs = 0 74 | 75 | def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): 76 | batch_size = input_ids.size(0) 77 | 78 | # Clear previous captures and GPU cache 79 | self.model.activation_capture.clear_captures() 80 | if self.device.type == "cuda": 81 | torch.cuda.empty_cache() 82 | 83 | # Forward pass 84 | with torch.no_grad(): 85 | _ = self.model(input_ids=input_ids, attention_mask=attention_mask) 86 | 87 | # Compute sparsity 88 | for layer_idx in range(self.num_layers): 89 | sparsity_masks = ( 90 | self.model.activation_capture.mlp_activations[Hook.ACT][layer_idx] <= 0 91 | ) 92 | 93 | # Naive sparsity computation 94 | self.mlp_sparsity[layer_idx].append( 95 | sparsity_masks.float().mean().item() 96 | ) 97 | 98 | # Level of sparsity after union over batch dim 99 | # union_sparsity_mask = sparsity_masks.any(dim=0) 100 | # self.union_sparsity[batch_size][layer_idx].append(union_sparsity_mask.float().mean().item()) 101 | 102 | # TODO: Add HNSW sparsity computation for both attn heads and mlp neurons 103 | # TODO: Compute union sparsity over multiple different batch sizes 104 | 105 | # Clear GPU tensors from capture to free memory 106 | self.model.activation_capture.clear_captures() 107 | if self.device.type == "cuda": 108 | torch.cuda.empty_cache() 109 | 110 | self.num_seqs += batch_size 111 | 112 | 113 | def analyze_sparsity(args, model_name, device): 114 | # Load model and tokenizer 115 | logger.info(f"Loading model: {model_name}") 116 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 117 | if tokenizer.pad_token is None: 118 | tokenizer.pad_token = tokenizer.eos_token 119 | 120 | model = AutoModelForCausalLM.from_pretrained( 121 
| model_name, 122 | torch_dtype=torch.float16 if device.type == "cuda" else torch.float32, 123 | device_map="auto" if device.type == "cuda" else None, 124 | trust_remote_code=True, 125 | ) 126 | 127 | if device.type != "cuda": 128 | model = model.to(device) 129 | 130 | # Load C4 dataset 131 | dataset = C4Dataset(tokenizer, args.max_length, args.num_samples) 132 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False) 133 | 134 | analyzer = ContextualSparsityAnalyzer(model, tokenizer, device) 135 | try: 136 | # Process dataset 137 | logger.info("Starting contextual sparsity analysis...") 138 | 139 | for batch_idx, batch in enumerate(tqdm(dataloader, desc="Analyzing sequences")): 140 | input_ids = batch["input_ids"].to(device) 141 | attention_mask = batch["attention_mask"].to(device) 142 | analyzer.process_batch(input_ids, attention_mask) 143 | 144 | # Log progress 145 | if (batch_idx + 1) % 100 == 0: 146 | logger.info(f"Processed {batch_idx + 1}/{len(dataloader)} sequences") 147 | 148 | analyzer.mlp_sparsity = [ 149 | sum(analyzer.mlp_sparsity[layer_idx]) / len(analyzer.mlp_sparsity[layer_idx]) 150 | for layer_idx in range(len(analyzer.mlp_sparsity)) 151 | ] 152 | finally: 153 | analyzer.model.activation_capture.remove_hooks() 154 | return analyzer.mlp_sparsity 155 | 156 | 157 | def plot_sparsities(sparsities, output_dir=None): 158 | plt.figure(figsize=(10, 6)) 159 | for model, model_sparsities in sparsities.items(): 160 | model_name = model.split("/")[1].capitalize() 161 | plt.plot([i*100/len(model_sparsities) for i in range(len(model_sparsities))], [x*100 for x in model_sparsities], label=model_name) 162 | plt.xlabel("Layer Index Percentage (layer_idx/num_layers)") 163 | plt.ylabel("% of Neurons Inactive") 164 | plt.title("Activation Sparsity by Layer") 165 | plt.legend() 166 | plt.minorticks_on() 167 | if output_dir: 168 | plt.savefig( 169 | os.path.join(output_dir, "sparsity_analysis.png"), 170 | dpi=300, 171 | bbox_inches="tight", 172 | ) 173 | 174 | 175 | class C4Dataset(Dataset): 176 | """C4 dataset for contextual sparsity analysis.""" 177 | 178 | def __init__(self, tokenizer, max_length: int = 512, num_samples: int = 1000): 179 | self.tokenizer = tokenizer 180 | self.max_length = max_length 181 | 182 | # Load C4 dataset 183 | logger.info("Loading C4 dataset...") 184 | dataset = load_dataset( 185 | "allenai/c4", "realnewslike", split="train", streaming=True 186 | ) 187 | 188 | # Process samples 189 | self.samples = [] 190 | for i, sample in enumerate(dataset): 191 | if i >= num_samples: 192 | break 193 | 194 | text = sample["text"] 195 | if len(text.strip()) > 50: # Filter out very short texts 196 | encoding = tokenizer( 197 | text, 198 | truncation=True, 199 | padding="max_length", 200 | max_length=max_length, 201 | return_tensors="pt", 202 | ) 203 | 204 | if ( 205 | encoding["input_ids"].shape[1] > 10 206 | ): # Ensure minimum sequence length 207 | self.samples.append( 208 | { 209 | "input_ids": encoding["input_ids"].squeeze(), 210 | "attention_mask": encoding["attention_mask"].squeeze(), 211 | "text": text[:200] + "..."
if len(text) > 200 else text, 212 | } 213 | ) 214 | 215 | logger.info(f"Loaded {len(self.samples)} C4 samples") 216 | 217 | def __len__(self): 218 | return len(self.samples) 219 | 220 | def __getitem__(self, idx): 221 | return self.samples[idx] 222 | 223 | 224 | def main(): 225 | parser = argparse.ArgumentParser( 226 | description="Measure contextual sparsity in LLaMA models" 227 | ) 228 | parser.add_argument( 229 | "--models", 230 | type=str, 231 | nargs="+", 232 | default=[ 233 | "meta-llama/Llama-3.2-3B-Instruct", 234 | "Qwen/Qwen2-1.5B", 235 | ], 236 | help="HuggingFace model names or paths", 237 | ) 238 | parser.add_argument( 239 | "--output_dir", type=str, required=True, help="Output directory for results" 240 | ) 241 | parser.add_argument( 242 | "--num_samples", type=int, default=1000, help="Number of C4 samples to analyze" 243 | ) 244 | parser.add_argument( 245 | "--max_length", type=int, default=512, help="Maximum sequence length" 246 | ) 247 | parser.add_argument( 248 | "--batch_size", 249 | type=int, 250 | default=1, 251 | help="Batch size (recommend 1 for token-by-token analysis)", 252 | ) 253 | parser.add_argument( 254 | "--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)" 255 | ) 256 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 257 | parser.add_argument( 258 | "--make_plots", action="store_true", help="Generate and save analysis plots" 259 | ) 260 | 261 | args = parser.parse_args() 262 | 263 | # Set seed 264 | set_seed(args.seed) 265 | 266 | # Setup device 267 | if args.device == "auto": 268 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 269 | else: 270 | device = torch.device(args.device) 271 | 272 | logger.info(f"Using device: {device}") 273 | 274 | # Setup output directory 275 | os.makedirs(args.output_dir, exist_ok=True) 276 | 277 | outs = defaultdict(dict) 278 | for model in args.models: 279 | outs[model] = analyze_sparsity(args, model, device) 280 | json.dump(outs, open(os.path.join(args.output_dir, "sparsity.json"), "w")) 281 | 282 | if args.make_plots: 283 | plot_sparsities(outs) 284 | 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "wheel", "torch"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sparse_transformers" 7 | version = "0.0.1" 8 | description = "Sparse Inference for transformers" 9 | authors = [ 10 | {name = "NimbleEdge"} 11 | ] 12 | requires-python = ">=3.7" 13 | 14 | [tool.setuptools] 15 | packages = ["sparse_transformers"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML/AI packages 2 | # conda install pytorch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia 3 | transformers==4.53.0 4 | numpy 5 | psutil 6 | optimum 7 | attrs 8 | scikit-learn 9 | accelerate 10 | datasets 11 | sentencepiece 12 | protobuf 13 | wandb 14 | ninja 15 | timm 16 | pillow 17 | lm-eval 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import 
BuildExtension, CppExtension, CUDAExtension 3 | import os 4 | import torch 5 | from pathlib import Path 6 | import shutil 7 | import sys 8 | import warnings 9 | 10 | 11 | # Parse custom build arguments 12 | def parse_build_args(): 13 | """Parse custom build arguments for CPU/GPU selection. 14 | 15 | Usage: 16 | python setup.py develop cpu # Force CPU-only build 17 | python setup.py develop gpu # Force GPU build (fallback to CPU if not available) 18 | python setup.py develop # Auto-detect (prefer GPU if available) 19 | """ 20 | build_mode = 'auto' # Default to auto-detect 21 | 22 | # Check for help request 23 | if 'help' in sys.argv or '--help' in sys.argv: 24 | print("\nSparse Transformers Build Options:") 25 | print(" python setup.py develop cpu # Force CPU-only build") 26 | print(" python setup.py develop gpu # Force GPU build") 27 | print(" python setup.py develop # Auto-detect (prefer GPU)") 28 | print() 29 | 30 | # Check for our custom arguments 31 | if 'cpu' in sys.argv: 32 | build_mode = 'cpu' 33 | sys.argv.remove('cpu') 34 | print("Forced CPU-only build mode") 35 | elif 'gpu' in sys.argv: 36 | build_mode = 'gpu' 37 | sys.argv.remove('gpu') 38 | print("Forced GPU build mode") 39 | else: 40 | print("Auto-detecting build mode (default: GPU if available)") 41 | 42 | return build_mode 43 | 44 | 45 | # Check PyTorch C++ ABI compatibility 46 | def get_pytorch_abi_flag(): 47 | """Get the correct C++ ABI flag to match PyTorch compilation.""" 48 | return f'-D_GLIBCXX_USE_CXX11_ABI={int(torch._C._GLIBCXX_USE_CXX11_ABI)}' 49 | 50 | 51 | # Get PyTorch ABI flag 52 | pytorch_abi_flag = get_pytorch_abi_flag() 53 | print(f"Using PyTorch C++ ABI flag: {pytorch_abi_flag}") 54 | 55 | # Parse build mode from command line 56 | build_mode = parse_build_args() 57 | 58 | # Create build directory if it doesn't exist 59 | build_dir = Path(__file__).parent / 'build' 60 | if build_dir.exists(): 61 | shutil.rmtree(build_dir) 62 | build_dir.mkdir(parents=True) 63 | (build_dir / 'lib').mkdir(exist_ok=True) 64 | 65 | # Set environment variables to control build output 66 | os.environ['TORCH_BUILD_DIR'] = str(build_dir) 67 | os.environ['BUILD_LIB'] = str(build_dir / 'lib') 68 | os.environ['BUILD_TEMP'] = str(build_dir / 'temp') 69 | 70 | # Get CUDA compute capability if GPU is available 71 | arch_flags = [] 72 | if torch.cuda.is_available(): 73 | try: 74 | arch_list = [] 75 | for i in range(torch.cuda.device_count()): 76 | arch_list.append(torch.cuda.get_device_capability(i)) 77 | arch_list = sorted(list(set(arch_list))) 78 | arch_flags = [ 79 | f"-gencode=arch=compute_{arch[0]}{arch[1]},code=sm_{arch[0]}{arch[1]}" 80 | for arch in arch_list 81 | ] 82 | print(f"CUDA architectures detected: {arch_list}") 83 | except Exception as e: 84 | warnings.warn(f"Error detecting CUDA architecture: {e}") 85 | # Use a common architecture as fallback 86 | arch_flags = ['-gencode=arch=compute_86,code=sm_86'] 87 | 88 | # Common optimization flags (compatible with both old and new ABI) 89 | common_compile_args = [ 90 | '-O3', # Maximum optimization 91 | '-fopenmp', # OpenMP support 92 | '-flto', # Link-time optimization 93 | '-funroll-loops', # Unroll loops 94 | '-fno-math-errno', # Assume math functions never set errno 95 | '-fno-trapping-math', # Assume FP ops don't generate traps 96 | '-mtune=native', # Tune code for local CPU 97 | pytorch_abi_flag, # Critical: Match PyTorch's C++ ABI 98 | '-DTORCH_API_INCLUDE_EXTENSION_H', # PyTorch extension header compatibility 99 | ] 100 | 101 | # Try to detect if we can use advanced CPU 
optimizations safely 102 | try: 103 | import platform 104 | 105 | if platform.machine() in ['x86_64', 'AMD64']: 106 | advanced_cpu_flags = [ 107 | '-march=native', # Optimize for local CPU architecture 108 | '-mtune=native', # Tune code for local CPU 109 | '-mavx2', # Enable AVX2 instructions if available 110 | '-mfma', # Enable FMA instructions if available 111 | ] 112 | else: 113 | advanced_cpu_flags = [] 114 | except: 115 | advanced_cpu_flags = [] 116 | 117 | # CPU-specific optimization flags 118 | cpu_compile_args = ( 119 | common_compile_args 120 | + advanced_cpu_flags 121 | + [ 122 | '-flto', # Link-time optimization 123 | '-funroll-loops', # Unroll loops 124 | '-fno-math-errno', # Assume math functions never set errno 125 | '-fno-trapping-math', # Assume FP ops don't generate traps 126 | '-fno-plt', # Improve indirect call performance 127 | '-fuse-linker-plugin', # Enable LTO plugin 128 | '-fomit-frame-pointer', # Remove frame pointers 129 | '-fno-stack-protector', # Disable stack protector 130 | '-fvisibility=hidden', # Hide all symbols by default 131 | '-fdata-sections', # Place each data item into its own section 132 | '-ffunction-sections', # Place each function into its own section 133 | '-fvisibility=default', 134 | ] 135 | ) 136 | 137 | # CUDA-specific optimization flags (ensure C++17 compatibility and ABI matching) 138 | cuda_compile_args = ( 139 | ['-O3', '--use_fast_math'] 140 | + arch_flags 141 | + [ 142 | '--compiler-options', 143 | "'-fPIC'", 144 | '--compiler-options', 145 | "'-O3'", 146 | '-std=c++17', # Force C++17 for compatibility 147 | '--compiler-options', 148 | "'-fvisibility=default'", 149 | ] 150 | ) 151 | 152 | # Add advanced CPU flags to CUDA compilation if available 153 | if advanced_cpu_flags: 154 | for flag in ['-march=native', '-ffast-math']: 155 | cuda_compile_args.extend(['--compiler-options', f"'{flag}'"]) 156 | 157 | # Link flags 158 | extra_link_args = [ 159 | '-fopenmp', 160 | '-flto', # Link-time optimization 161 | '-fuse-linker-plugin', # Enable LTO plugin 162 | '-Wl,--as-needed', # Only link needed libraries 163 | '-Wl,-O3', # Linker optimizations 164 | '-Wl,--strip-all', # Strip all symbols 165 | '-Wl,--gc-sections', # Remove unused sections 166 | '-Wl,--exclude-libs,ALL', # Don't export any symbols from libraries 167 | ] 168 | 169 | 170 | # Get CUDA include paths 171 | def get_cuda_include_dirs(): 172 | cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda') 173 | if not os.path.exists(cuda_home): 174 | cuda_home = os.getenv('CUDA_PATH') # Windows 175 | 176 | if cuda_home is None: 177 | # Try common CUDA locations 178 | for path in ['/usr/local/cuda', '/opt/cuda', '/usr/cuda']: 179 | if os.path.exists(path): 180 | cuda_home = path 181 | break 182 | 183 | if cuda_home is None: 184 | warnings.warn('CUDA installation not found. 
CUDA extensions will not be built.') 185 | return [] 186 | 187 | return [ 188 | os.path.join(cuda_home, 'include'), 189 | os.path.join(cuda_home, 'samples', 'common', 'inc'), 190 | ] 191 | 192 | 193 | # Base extension configuration 194 | base_include_dirs = [ 195 | os.path.dirname(torch.__file__) + '/include', 196 | os.path.dirname(torch.__file__) + '/include/torch/csrc/api/include', 197 | os.path.dirname(torch.__file__) + '/include/ATen', 198 | os.path.dirname(torch.__file__) + '/include/c10', 199 | ] 200 | 201 | # Define extensions 202 | ext_modules = [] 203 | 204 | cpp_source = 'sparse_transformers/csrc/sparse_mlp_op.cpp' 205 | cuda_source = 'sparse_transformers/csrc/sparse_mlp_cuda.cu' 206 | 207 | if not os.path.exists(cpp_source): 208 | warnings.warn(f"C++ source file not found: {cpp_source}") 209 | raise FileNotFoundError(f"Missing source file: {cpp_source}") 210 | 211 | # Determine if we should build CUDA extension based on build mode 212 | should_build_cuda = False 213 | 214 | if build_mode == 'cpu': 215 | print("CPU-only build requested - skipping CUDA") 216 | should_build_cuda = False 217 | elif build_mode == 'gpu': 218 | print("GPU build requested") 219 | if not torch.cuda.is_available(): 220 | print("WARNING: GPU build requested but PyTorch CUDA not available") 221 | print(" Falling back to CPU-only build") 222 | should_build_cuda = False 223 | elif not os.path.exists(cuda_source): 224 | print("WARNING: GPU build requested but CUDA source file not found") 225 | print(" Falling back to CPU-only build") 226 | should_build_cuda = False 227 | else: 228 | should_build_cuda = True 229 | else: # auto mode 230 | # Default behavior: prefer GPU if available, otherwise CPU 231 | if torch.cuda.is_available() and os.path.exists(cuda_source): 232 | print("Auto-detected: Building GPU extension (CUDA available)") 233 | should_build_cuda = True 234 | else: 235 | print("Auto-detected: Building CPU-only extension (CUDA not available)") 236 | should_build_cuda = False 237 | 238 | if should_build_cuda: 239 | print("Building CUDA extension...") 240 | cuda_include_dirs = get_cuda_include_dirs() 241 | if cuda_include_dirs: 242 | base_include_dirs.extend(cuda_include_dirs) 243 | extension = CUDAExtension( 244 | name='sparse_transformers.sparse_transformers', 245 | sources=[cpp_source, cuda_source], 246 | include_dirs=base_include_dirs, 247 | extra_compile_args={'cxx': cpu_compile_args, 'nvcc': cuda_compile_args}, 248 | extra_link_args=extra_link_args, 249 | libraries=['gomp', 'cudart'], 250 | library_dirs=[str(build_dir / 'lib')], 251 | define_macros=[('WITH_CUDA', None)], 252 | ) 253 | else: 254 | print( 255 | "CUDA include directories not found, falling back to CPU-only extension..." 
256 | ) 257 | should_build_cuda = False 258 | 259 | if not should_build_cuda: 260 | print("Building CPU-only extension...") 261 | extension = CppExtension( 262 | name='sparse_transformers.sparse_transformers', 263 | sources=[cpp_source], 264 | extra_compile_args=cpu_compile_args, 265 | extra_link_args=extra_link_args, 266 | library_dirs=[str(build_dir / 'lib')], 267 | include_dirs=base_include_dirs, 268 | libraries=['gomp'], 269 | define_macros=[('CPU_ONLY', None)], 270 | ) 271 | 272 | ext_modules.append(extension) 273 | build_type = "CUDA" if should_build_cuda else "CPU-only" 274 | print(f"Extension configured successfully: {extension.name} ({build_type})") 275 | 276 | 277 | # Custom build extension to handle clean builds and ABI compatibility 278 | class CustomBuildExtension(BuildExtension): 279 | def get_ext_filename(self, ext_name): 280 | # Force output to build directory 281 | filename = super().get_ext_filename(ext_name) 282 | return str(build_dir / 'lib' / os.path.basename(filename)) 283 | 284 | def get_ext_fullpath(self, ext_name): 285 | # Override to ensure extension is built in our build directory 286 | filename = self.get_ext_filename(ext_name) 287 | return str(build_dir / 'lib' / filename) 288 | 289 | def build_extensions(self): 290 | # Disable parallel build for better error reporting and CUDA compatibility 291 | if self.parallel: 292 | self.parallel = False 293 | 294 | # Print compilation info for debugging 295 | print(f"Building extensions with PyTorch {torch.__version__}") 296 | print(f"PyTorch C++ ABI: {pytorch_abi_flag}") 297 | super().build_extensions() 298 | print("C++ extension built successfully!") 299 | 300 | 301 | # Read requirements from requirements.txt 302 | def read_requirements(): 303 | requirements_path = Path(__file__).parent / 'requirements.txt' 304 | if requirements_path.exists(): 305 | with open(requirements_path, 'r') as f: 306 | requirements = [] 307 | for line in f: 308 | line = line.strip() 309 | if line and not line.startswith('#'): 310 | requirements.append(line) 311 | return requirements 312 | return [] 313 | 314 | 315 | setup( 316 | name='sparse_transformers', 317 | version='0.0.1', 318 | description='Sparse Inferencing for transformer based LLMs', 319 | packages=find_packages(), 320 | ext_modules=ext_modules, 321 | cmdclass={ 322 | 'build_ext': CustomBuildExtension.with_options(no_python_abi_suffix=True), 323 | }, 324 | install_requires=read_requirements(), 325 | python_requires='>=3.8', 326 | include_package_data=True, 327 | zip_safe=False, # Required for C++ extensions 328 | ) 329 | -------------------------------------------------------------------------------- /sparse_transformers/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1 FATAL_ERROR) 2 | project(sparse_mlp) 3 | 4 | find_package(Torch REQUIRED) 5 | find_package(OpenMP REQUIRED) 6 | 7 | # Define our library target 8 | add_library(sparse_mlp SHARED 9 | csrc/sparse_mlp_op.cpp 10 | ) 11 | 12 | # Enable C++17 13 | target_compile_features(sparse_mlp PRIVATE cxx_std_17) 14 | 15 | # Add OpenMP flags 16 | target_compile_options(sparse_mlp PRIVATE ${OpenMP_CXX_FLAGS}) 17 | 18 | # Include directories 19 | target_include_directories(sparse_mlp PRIVATE 20 | ${TORCH_INCLUDE_DIRS} 21 | ${CMAKE_CURRENT_SOURCE_DIR}/csrc 22 | ) 23 | 24 | # Link against LibTorch and OpenMP 25 | target_link_libraries(sparse_mlp PRIVATE 26 | ${TORCH_LIBRARIES} 27 | OpenMP::OpenMP_CXX 28 | ) 29 | 30 | # Set output directory 31 | 
set_target_properties(sparse_mlp PROPERTIES 32 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib" 33 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/bin" 34 | ) 35 | 36 | # Add optimization flags 37 | target_compile_options(sparse_mlp PRIVATE 38 | -O3 39 | -ffast-math 40 | -march=native 41 | ) -------------------------------------------------------------------------------- /sparse_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # Configure CPU threads and threading 5 | num_threads = os.cpu_count() 6 | print(f"Configuring for {num_threads} CPU threads") 7 | os.environ['OMP_NUM_THREADS'] = str(num_threads) 8 | os.environ['MKL_NUM_THREADS'] = str(num_threads) 9 | os.environ['OPENBLAS_NUM_THREADS'] = str(num_threads) 10 | os.environ['VECLIB_MAXIMUM_THREADS'] = str(num_threads) 11 | os.environ['NUMEXPR_NUM_THREADS'] = str(num_threads) 12 | torch.set_num_threads(num_threads) 13 | torch.set_num_interop_threads(num_threads) 14 | os.environ['MAX_JOBS'] = str(num_threads) 15 | 16 | # Enable TorchScript optimizations 17 | torch.jit.enable_onednn_fusion(True) 18 | torch._C._jit_override_can_fuse_on_cpu(True) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | torch._C._jit_set_texpr_fuser_enabled(True) 21 | torch._C._jit_set_profiling_executor(True) 22 | torch._C._jit_set_profiling_mode(True) 23 | 24 | torch.classes.load_library(os.path.join(os.path.dirname(__file__), "sparse_transformers.so")) 25 | 26 | # Only define these if the extension loaded successfully 27 | #sparse_mlp_forward = torch.ops.sparse_mlp.forward 28 | WeightCache = torch.classes.sparse_mlp.WeightCache 29 | #approx_topk_threshold = torch.ops.sparse_mlp.approx_topk_threshold 30 | __all__ = [ 31 | # 'sparse_mlp_forward', 32 | 'WeightCache', 33 | # 'approx_topk_threshold' 34 | ] -------------------------------------------------------------------------------- /sparse_transformers/csrc/.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignArrayOfStructures: None 7 | AlignConsecutiveAssignments: 8 | Enabled: false 9 | AcrossEmptyLines: false 10 | AcrossComments: false 11 | AlignCompound: false 12 | PadOperators: true 13 | AlignConsecutiveBitFields: 14 | Enabled: false 15 | AcrossEmptyLines: false 16 | AcrossComments: false 17 | AlignCompound: false 18 | PadOperators: false 19 | AlignConsecutiveDeclarations: 20 | Enabled: false 21 | AcrossEmptyLines: false 22 | AcrossComments: false 23 | AlignCompound: false 24 | PadOperators: false 25 | AlignConsecutiveMacros: 26 | Enabled: false 27 | AcrossEmptyLines: false 28 | AcrossComments: false 29 | AlignCompound: false 30 | PadOperators: false 31 | AlignEscapedNewlines: Left 32 | AlignOperands: Align 33 | AlignTrailingComments: true 34 | AllowAllArgumentsOnNextLine: true 35 | AllowAllParametersOfDeclarationOnNextLine: true 36 | AllowShortEnumsOnASingleLine: true 37 | AllowShortBlocksOnASingleLine: Never 38 | AllowShortCaseLabelsOnASingleLine: false 39 | AllowShortFunctionsOnASingleLine: All 40 | AllowShortLambdasOnASingleLine: All 41 | AllowShortIfStatementsOnASingleLine: WithoutElse 42 | AllowShortLoopsOnASingleLine: true 43 | AlwaysBreakAfterDefinitionReturnType: None 44 | AlwaysBreakAfterReturnType: None 45 | AlwaysBreakBeforeMultilineStrings: true 46 | AlwaysBreakTemplateDeclarations: Yes 47 | AttributeMacros: 48 | 
- __capability 49 | BinPackArguments: true 50 | BinPackParameters: true 51 | BraceWrapping: 52 | AfterCaseLabel: false 53 | AfterClass: false 54 | AfterControlStatement: Never 55 | AfterEnum: false 56 | AfterFunction: false 57 | AfterNamespace: false 58 | AfterObjCDeclaration: false 59 | AfterStruct: false 60 | AfterUnion: false 61 | AfterExternBlock: false 62 | BeforeCatch: false 63 | BeforeElse: false 64 | BeforeLambdaBody: false 65 | BeforeWhile: false 66 | IndentBraces: false 67 | SplitEmptyFunction: true 68 | SplitEmptyRecord: true 69 | SplitEmptyNamespace: true 70 | BreakBeforeBinaryOperators: None 71 | BreakBeforeConceptDeclarations: Always 72 | BreakBeforeBraces: Attach 73 | BreakBeforeInheritanceComma: false 74 | BreakInheritanceList: BeforeColon 75 | BreakBeforeTernaryOperators: true 76 | BreakConstructorInitializersBeforeComma: false 77 | BreakConstructorInitializers: BeforeColon 78 | BreakAfterJavaFieldAnnotations: false 79 | BreakStringLiterals: true 80 | ColumnLimit: 100 81 | CommentPragmas: '^ IWYU pragma:' 82 | QualifierAlignment: Leave 83 | CompactNamespaces: false 84 | ConstructorInitializerIndentWidth: 4 85 | ContinuationIndentWidth: 4 86 | Cpp11BracedListStyle: true 87 | DeriveLineEnding: true 88 | DerivePointerAlignment: true 89 | DisableFormat: false 90 | EmptyLineAfterAccessModifier: Never 91 | EmptyLineBeforeAccessModifier: LogicalBlock 92 | ExperimentalAutoDetectBinPacking: false 93 | PackConstructorInitializers: NextLine 94 | BasedOnStyle: '' 95 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 96 | AllowAllConstructorInitializersOnNextLine: true 97 | FixNamespaceComments: true 98 | ForEachMacros: 99 | - foreach 100 | - Q_FOREACH 101 | - BOOST_FOREACH 102 | IfMacros: 103 | - KJ_IF_MAYBE 104 | IncludeBlocks: Regroup 105 | IncludeCategories: 106 | - Regex: '^' 107 | Priority: 2 108 | SortPriority: 0 109 | CaseSensitive: false 110 | - Regex: '^<.*\.h>' 111 | Priority: 1 112 | SortPriority: 0 113 | CaseSensitive: false 114 | - Regex: '^<.*' 115 | Priority: 2 116 | SortPriority: 0 117 | CaseSensitive: false 118 | - Regex: '.*' 119 | Priority: 3 120 | SortPriority: 0 121 | CaseSensitive: false 122 | IncludeIsMainRegex: '([-_](test|unittest))?$' 123 | IncludeIsMainSourceRegex: '' 124 | IndentAccessModifiers: false 125 | IndentCaseLabels: true 126 | IndentCaseBlocks: false 127 | IndentGotoLabels: true 128 | IndentPPDirectives: None 129 | IndentExternBlock: AfterExternBlock 130 | IndentRequiresClause: true 131 | IndentWidth: 2 132 | IndentWrappedFunctionNames: false 133 | InsertBraces: false 134 | InsertTrailingCommas: None 135 | JavaScriptQuotes: Leave 136 | JavaScriptWrapImports: true 137 | KeepEmptyLinesAtTheStartOfBlocks: false 138 | LambdaBodyIndentation: Signature 139 | MacroBlockBegin: '' 140 | MacroBlockEnd: '' 141 | MaxEmptyLinesToKeep: 1 142 | NamespaceIndentation: None 143 | ObjCBinPackProtocolList: Never 144 | ObjCBlockIndentWidth: 2 145 | ObjCBreakBeforeNestedBlockParam: true 146 | ObjCSpaceAfterProperty: false 147 | ObjCSpaceBeforeProtocolList: true 148 | PenaltyBreakAssignment: 2 149 | PenaltyBreakBeforeFirstCallParameter: 1 150 | PenaltyBreakComment: 300 151 | PenaltyBreakFirstLessLess: 120 152 | PenaltyBreakOpenParenthesis: 0 153 | PenaltyBreakString: 1000 154 | PenaltyBreakTemplateDeclaration: 10 155 | PenaltyExcessCharacter: 1000000 156 | PenaltyReturnTypeOnItsOwnLine: 200 157 | PenaltyIndentedWhitespace: 0 158 | PointerAlignment: Left 159 | PPIndentWidth: -1 160 | RawStringFormats: 161 | - Language: Cpp 162 | Delimiters: 163 | - cc 164 | - CC 165 
| - cpp 166 | - Cpp 167 | - CPP 168 | - 'c++' 169 | - 'C++' 170 | CanonicalDelimiter: '' 171 | BasedOnStyle: google 172 | - Language: TextProto 173 | Delimiters: 174 | - pb 175 | - PB 176 | - proto 177 | - PROTO 178 | EnclosingFunctions: 179 | - EqualsProto 180 | - EquivToProto 181 | - PARSE_PARTIAL_TEXT_PROTO 182 | - PARSE_TEST_PROTO 183 | - PARSE_TEXT_PROTO 184 | - ParseTextOrDie 185 | - ParseTextProtoOrDie 186 | - ParseTestProto 187 | - ParsePartialTestProto 188 | CanonicalDelimiter: pb 189 | BasedOnStyle: google 190 | ReferenceAlignment: Pointer 191 | ReflowComments: true 192 | RemoveBracesLLVM: false 193 | RequiresClausePosition: OwnLine 194 | SeparateDefinitionBlocks: Always 195 | ShortNamespaceLines: 1 196 | SortIncludes: CaseSensitive 197 | SortJavaStaticImport: Before 198 | SortUsingDeclarations: true 199 | SpaceAfterCStyleCast: false 200 | SpaceAfterLogicalNot: false 201 | SpaceAfterTemplateKeyword: true 202 | SpaceBeforeAssignmentOperators: true 203 | SpaceBeforeCaseColon: false 204 | SpaceBeforeCpp11BracedList: false 205 | SpaceBeforeCtorInitializerColon: true 206 | SpaceBeforeInheritanceColon: true 207 | SpaceBeforeParens: ControlStatements 208 | SpaceBeforeParensOptions: 209 | AfterControlStatements: true 210 | AfterForeachMacros: true 211 | AfterFunctionDefinitionName: false 212 | AfterFunctionDeclarationName: false 213 | AfterIfMacros: true 214 | AfterOverloadedOperator: false 215 | AfterRequiresInClause: false 216 | AfterRequiresInExpression: false 217 | BeforeNonEmptyParentheses: false 218 | SpaceAroundPointerQualifiers: Default 219 | SpaceBeforeRangeBasedForLoopColon: true 220 | SpaceInEmptyBlock: false 221 | SpaceInEmptyParentheses: false 222 | SpacesBeforeTrailingComments: 2 223 | SpacesInAngles: Never 224 | SpacesInConditionalStatement: false 225 | SpacesInContainerLiterals: true 226 | SpacesInCStyleCastParentheses: false 227 | SpacesInLineCommentPrefix: 228 | Minimum: 1 229 | Maximum: -1 230 | SpacesInParentheses: false 231 | SpacesInSquareBrackets: false 232 | SpaceBeforeSquareBrackets: false 233 | BitFieldColonSpacing: Both 234 | Standard: Auto 235 | StatementAttributeLikeMacros: 236 | - Q_EMIT 237 | StatementMacros: 238 | - Q_UNUSED 239 | - QT_REQUIRE_VERSION 240 | TabWidth: 8 241 | UseCRLF: false 242 | UseTab: Never 243 | WhitespaceSensitiveMacros: 244 | - STRINGIZE 245 | - PP_STRINGIZE 246 | - BOOST_PP_STRINGIZE 247 | - NS_SWIFT_NAME 248 | - CF_SWIFT_NAME 249 | ... 
-------------------------------------------------------------------------------- /sparse_transformers/csrc/approx_topk.h: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/Parallel.h>
3 | #include <algorithm>
4 | #include <cstdint>
5 | #include <functional>
6 | #include <limits>
7 | #include <vector>
8 | 
9 | // Count-Min Sketch inspired method - O(n) time complexity with parallel batch processing
10 | torch::Tensor approx_topk_threshold(
11 | const torch::Tensor &scores,
12 | int64_t k)
13 | {
14 | TORCH_CHECK(scores.dim() == 2, "Input scores must be 2D tensor [batch_size, features]");
15 | TORCH_CHECK(k > 0, "k must be positive");
16 | 
17 | auto batch_size = scores.size(0);
18 | auto feature_size = scores.size(1);
19 | 
20 | TORCH_CHECK(k <= feature_size, "k cannot be larger than feature size");
21 | 
22 | auto options = torch::TensorOptions().dtype(scores.dtype()).device(scores.device());
23 | auto threshold = torch::zeros({batch_size, 1}, options);
24 | 
25 | // Sketch parameters
26 | const int num_sketches = 4;
27 | const int sketch_width = std::min(1024L, feature_size / 4);
28 | 
29 | // Standard C++ hash function
30 | std::hash<int64_t> hasher;
31 | 
32 | // Process each batch item in parallel using at::parallel_for
33 | AT_DISPATCH_FLOATING_TYPES(scores.scalar_type(), "approx_topk_count_min_sketch", [&]
34 | {
35 | auto scores_accessor = scores.accessor<scalar_t, 2>();
36 | auto threshold_accessor = threshold.accessor<scalar_t, 2>();
37 | 
38 | // Parallel processing over batch dimension
39 | // Use grain_size of 1 for fine-grained parallelism
40 | at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
41 | for (int64_t batch_idx = start; batch_idx < end; ++batch_idx) {
42 | // Initialize sketches with negative infinity (thread-local)
43 | std::vector<std::vector<scalar_t>> sketches(num_sketches,
44 | std::vector<scalar_t>(sketch_width, -std::numeric_limits<scalar_t>::infinity()));
45 | 
46 | // Update sketches with maximum values at hash positions
47 | for (int sketch_idx = 0; sketch_idx < num_sketches; ++sketch_idx) {
48 | for (int64_t feature_idx = 0; feature_idx < feature_size; ++feature_idx) {
49 | // Use different hash functions for each sketch by combining with sketch_idx
50 | int64_t combined_key = sketch_idx * feature_size + feature_idx;
51 | int64_t hash_pos = hasher(combined_key) % sketch_width;
52 | 
53 | scalar_t value = scores_accessor[batch_idx][feature_idx];
54 | sketches[sketch_idx][hash_pos] = std::max(sketches[sketch_idx][hash_pos], value);
55 | }
56 | }
57 | 
58 | // Collect all sketch values (thread-local)
59 | std::vector<scalar_t> all_sketch_values;
60 | for (const auto& sketch : sketches) {
61 | for (scalar_t val : sketch) {
62 | if (val != -std::numeric_limits<scalar_t>::infinity()) {
63 | all_sketch_values.push_back(val);
64 | }
65 | }
66 | }
67 | 
68 | if (!all_sketch_values.empty()) {
69 | // Find approximate threshold
70 | int64_t sketch_k = std::max(1L, static_cast<int64_t>(k * all_sketch_values.size() / feature_size));
71 | sketch_k = std::min(sketch_k, static_cast<int64_t>(all_sketch_values.size()));
72 | 
73 | std::nth_element(all_sketch_values.begin(),
74 | all_sketch_values.begin() + sketch_k - 1,
75 | all_sketch_values.end(),
76 | std::greater<scalar_t>());
77 | 
78 | // Apply adjustment factor for approximation error
79 | scalar_t adjustment_factor = 0.9;
80 | threshold_accessor[batch_idx][0] = all_sketch_values[sketch_k - 1] * adjustment_factor;
81 | } else {
82 | threshold_accessor[batch_idx][0] = 0.0;
83 | }
84 | }
85 | }); });
86 | 
87 | return threshold;
88 | } --------------------------------------------------------------------------------
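The Count-Min-Sketch routine above is registered as a TorchScript operator (`sparse_mlp.approx_topk_threshold`) by `sparse_mlp_op.cpp` below, so it can be called from Python once the compiled extension is loaded. The following is a minimal sketch of that call path, assuming the extension has been built via `setup.py`; the Python-side alias in `sparse_transformers/__init__.py` is currently commented out, so the raw `torch.ops` handle is used here.

```python
import torch

# Importing the package loads sparse_transformers.so via torch.classes.load_library,
# which registers everything declared in TORCH_LIBRARY(sparse_mlp, ...).
import sparse_transformers  # noqa: F401

# Example per-token activation scores: 4 rows (tokens), 8192 intermediate neurons.
scores = torch.randn(4, 8192)

# Approximate per-row threshold that keeps roughly the top-512 scores in each row.
threshold = torch.ops.sparse_mlp.approx_topk_threshold(scores, 512)
print(threshold.shape)  # torch.Size([4, 1])

# Neurons at or above the threshold would be treated as active.
active_mask = scores >= threshold
```

Because the sketch only approximates the k-th largest value (and scales it by the 0.9 adjustment factor), the number of entries above the returned threshold is close to, but not exactly, `k`.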
/sparse_transformers/csrc/sparse_mlp_op.cpp: --------------------------------------------------------------------------------
1 | // For TorchScript support
2 | #include <torch/script.h>
3 | 
4 | // For PyTorch C++ extension support
5 | #include <torch/extension.h>
6 | 
7 | // For tensor operations
8 | #include <ATen/ATen.h>
9 | 
10 | // For PyTorch's OpenMP wrapper
11 | #include <ATen/Parallel.h>
12 | 
13 | // Add pybind11 and namespace
14 | #include <pybind11/pybind11.h>
15 | namespace py = pybind11;
16 | 
17 | // Add required headers
18 | #include <vector>
19 | #include <string>
20 | #include <algorithm>
21 | 
22 | // Add device check utilities (only if not CPU-only build)
23 | #ifndef CPU_ONLY
24 | #include <c10/cuda/CUDAGuard.h>
25 | #endif
26 | 
27 | // Add custom headers
28 | #include "weight_cache.h"
29 | #include "approx_topk.h"
30 | 
31 | // Forward declarations of CPU/CUDA implementations
32 | torch::Tensor sparse_mlp_forward_cpu(
33 | const torch::Tensor &input,
34 | const torch::Tensor &concat_weight,
35 | const torch::Tensor &active_down_weight,
36 | torch::Tensor &down_proj_buffer,
37 | torch::Tensor &combined_proj_buffer,
38 | const std::string &activation_fn,
39 | bool has_gate);
40 | 
41 | #ifdef WITH_CUDA
42 | torch::Tensor sparse_mlp_forward_cuda(
43 | const torch::Tensor &input,
44 | const torch::Tensor &concat_weight,
45 | const torch::Tensor &active_down_weight,
46 | torch::Tensor &down_proj_buffer,
47 | torch::Tensor &combined_proj_buffer,
48 | const std::string &activation_fn,
49 | bool has_gate);
50 | #endif
51 | 
52 | // Main dispatch function
53 | torch::Tensor sparse_mlp_forward(
54 | const torch::Tensor &input,
55 | const torch::Tensor &concat_weight,
56 | const torch::Tensor &active_down_weight,
57 | torch::Tensor &down_proj_buffer,
58 | torch::Tensor &combined_proj_buffer,
59 | const std::string &activation_fn,
60 | bool has_gate=true)
61 | {
62 | 
63 | // Check if input is on CUDA and dispatch accordingly
64 | if (input.is_cuda())
65 | {
66 | #ifdef WITH_CUDA
67 | return sparse_mlp_forward_cuda(input, concat_weight, active_down_weight, down_proj_buffer, combined_proj_buffer, activation_fn, has_gate);
68 | #else
69 | AT_ERROR("CUDA not available - cannot run on GPU");
70 | #endif
71 | }
72 | else
73 | {
74 | return sparse_mlp_forward_cpu(input, concat_weight, active_down_weight, down_proj_buffer, combined_proj_buffer, activation_fn, has_gate);
75 | }
76 | }
77 | 
78 | // CPU implementation
79 | torch::Tensor sparse_mlp_forward_cpu(
80 | const torch::Tensor &input,
81 | const torch::Tensor &concat_weight,
82 | const torch::Tensor &active_down_weight,
83 | torch::Tensor &down_proj_buffer,
84 | torch::Tensor &combined_proj_buffer,
85 | const std::string &activation_fn,
86 | bool has_gate)
87 | {
88 | // Store original input shape for reshaping output later
89 | auto original_shape = input.sizes().vec();
90 | bool needs_reshape = input.dim() > 2;
91 | 
92 | // Flatten input if it has more than 2 dimensions
93 | torch::Tensor input_2d;
94 | if (needs_reshape)
95 | {
96 | // Flatten all dimensions except the last one (hidden dimension)
97 | auto hidden_size = original_shape.back();
98 | auto total_batch_size = input.numel() / hidden_size;
99 | input_2d = input.view({total_batch_size, hidden_size});
100 | } else {
101 | input_2d = input;
102 | }
103 | 
104 | const auto batch_size = input_2d.size(0);
105 | const auto hidden_size = input_2d.size(1);
106 | 
107 | // Ensure output buffer is correctly sized
108 | // Check both dimensions to avoid resize warnings
109 | const int64_t expected_gate_dim = concat_weight.size(0);
110 | 
111 | // For down_proj_buffer: [batch_size, hidden_size]
112 | if (down_proj_buffer.size(0) !=
batch_size || down_proj_buffer.size(1) != hidden_size) 113 | down_proj_buffer.resize_({batch_size, hidden_size}); 114 | 115 | // For combined_proj_buffer: [batch_size, 2 * gate_size] 116 | if (combined_proj_buffer.size(0) != batch_size || combined_proj_buffer.size(1) != expected_gate_dim) 117 | combined_proj_buffer.resize_({batch_size, expected_gate_dim}); 118 | 119 | // Optimal grain size for heavy matmul operations 120 | const int64_t num_threads = at::get_num_threads(); 121 | int64_t grain_size = 1; 122 | 123 | if (batch_size > num_threads && num_threads > 0) { 124 | // Base calculation: create 2-4x more work chunks than threads 125 | const int64_t target_chunks = num_threads * 3; // 3x oversubscription 126 | grain_size = std::max(int64_t(1), batch_size / target_chunks); 127 | 128 | // Cap at 32 for heavy operations (much smaller than 64 for light ops) 129 | const int64_t max_grain_matmul = 32; 130 | grain_size = std::min(grain_size, max_grain_matmul); 131 | } 132 | 133 | if (has_gate) 134 | { 135 | // Process each batch block in parallel 136 | at::parallel_for(0, batch_size, grain_size, [&](int64_t start, int64_t end) 137 | { 138 | // Process blocks of batches instead of single items 139 | const int64_t block_size = end - start; 140 | const int64_t gate_size = concat_weight.size(0) / 2; 141 | 142 | // Get input block for this thread 143 | auto input_block = input_2d.slice(0, start, end); // [block_size, hidden_size] 144 | 145 | // Get output buffer views for this block 146 | auto combined_proj_block = combined_proj_buffer.slice(0, start, end); // [block_size, 2*gate_size] 147 | auto down_proj_block = down_proj_buffer.slice(0, start, end); // [block_size, hidden_size] 148 | 149 | // Perform batch matrix multiplication for gate and up projections 150 | // This is more efficient than individual matmuls 151 | torch::matmul_out(combined_proj_block, input_block, concat_weight.t()); 152 | 153 | // Split into gate and up projections 154 | auto gate_proj = combined_proj_block.narrow(1, 0, gate_size); // [block_size, gate_size] 155 | auto up_proj = combined_proj_block.narrow(1, gate_size, gate_size); // [block_size, gate_size] 156 | 157 | gate_proj.relu_(); // In-place sigmoid 158 | gate_proj.mul_(up_proj); // In-place element-wise multiplication 159 | 160 | // Final projection to output dimension 161 | torch::matmul_out(down_proj_block, gate_proj, active_down_weight.t()); 162 | }); 163 | } 164 | else 165 | { 166 | // Process each batch block in parallel 167 | at::parallel_for(0, batch_size, grain_size, [&](int64_t start, int64_t end) 168 | { 169 | // Process blocks of batches instead of single items 170 | const int64_t block_size = end - start; 171 | 172 | // Get input block for this thread 173 | auto input_block = input_2d.slice(0, start, end); // [block_size, hidden_size] 174 | 175 | // Get output buffer views for this block 176 | auto combined_proj_block = combined_proj_buffer.slice(0, start, end); // [block_size, 2*gate_size] 177 | auto down_proj_block = down_proj_buffer.slice(0, start, end); // [block_size, hidden_size] 178 | 179 | // Perform batch matrix multiplication for gate and up projections 180 | // This is more efficient than individual matmuls 181 | torch::matmul_out(combined_proj_block, input_block, concat_weight.t()); 182 | 183 | combined_proj_block.relu_(); // In-place sigmoid 184 | 185 | // Final projection to output dimension 186 | torch::matmul_out(down_proj_block, combined_proj_block, active_down_weight.t()); 187 | }); 188 | } 189 | 190 | // Reshape output back to original 
shape if input was multi-dimensional 191 | return needs_reshape ? down_proj_buffer.view(original_shape) : down_proj_buffer; 192 | } 193 | 194 | // Register TorchScript custom classes and operators 195 | TORCH_LIBRARY(sparse_mlp, m) 196 | { 197 | // Register the optimized weight cache 198 | m.class_("WeightCache") 199 | .def(torch::init()) 200 | .def("update_active_weights", &WeightCache::update_active_weights) 201 | .def("get_concat_weight", &WeightCache::get_concat_weight) 202 | .def("get_active_down_weight", &WeightCache::get_active_down_weight); 203 | 204 | // Register sparse MLP operator 205 | m.def("forward", sparse_mlp_forward); 206 | 207 | // Register Count-Min Sketch approximate top-k threshold operator 208 | m.def("approx_topk_threshold", approx_topk_threshold); 209 | } -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Load activation capture first to initialize registry 2 | from . import activation_capture 3 | 4 | from . import models 5 | from . import utilities -------------------------------------------------------------------------------- /src/activation_capture.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Enum 3 | from typing import List 4 | 5 | 6 | class Hook(Enum): 7 | IN = "IN" 8 | ACT = "ACT" 9 | UP = "UP" 10 | OUT = "OUT" 11 | 12 | 13 | class ActivationCapture(): 14 | """Helper class to capture activations from model layers.""" 15 | hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT] 16 | 17 | def __init__(self, model): 18 | self.model = model 19 | self.mlp_activations = { 20 | hook: {} for hook in self.hooks_available 21 | } 22 | self.handles = [] 23 | 24 | def _register_in_hook(self, layer_idx, layer): 25 | def hook(module, input, output): 26 | # Just detach, don't clone or move to CPU yet 27 | self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach() 28 | return output 29 | handle = layer.register_forward_hook(hook) 30 | return handle 31 | 32 | def _register_act_hook(self, layer_idx, layer): 33 | def hook(module, input, output): 34 | # Just detach, don't clone or move to CPU yet 35 | self.mlp_activations[Hook.ACT][layer_idx] = output[0].clone().detach() 36 | return output 37 | handle = layer.mlp.act_fn.register_forward_hook(hook) 38 | return handle 39 | 40 | def _register_up_hook(self, layer_idx, layer): 41 | def hook(module, input, output): 42 | # Just detach, don't clone or move to CPU yet 43 | self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach() 44 | return output 45 | handle = layer.mlp.down_proj.register_forward_hook(hook) 46 | return handle 47 | 48 | def _register_out_hook(self, layer_idx, layer): 49 | def hook(module, input, output): 50 | # Just detach, don't clone or move to CPU yet 51 | self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach() 52 | return output 53 | handle = layer.mlp.register_forward_hook(hook) 54 | return handle 55 | 56 | def get_layers(self): 57 | return self.model.get_decoder().layers 58 | 59 | def register_hooks(self, hooks): 60 | """Register forward hooks to capture activations.""" 61 | # Clear any existing hooks 62 | self.remove_hooks() 63 | 64 | # Hook into each transformer layer 65 | for i, layer in enumerate(self.get_layers()): 66 | # Hooks capturing inputs to the MLP layer 67 | if Hook.IN in hooks and Hook.IN in self.hooks_available: 68 | handle = self._register_in_hook(i, layer) 
69 | if handle is not None: 70 | self.handles.append(handle) 71 | 72 | # Hooks capturing inputs to the activation function 73 | if Hook.ACT in hooks and Hook.ACT in self.hooks_available: 74 | handle = self._register_act_hook(i, layer) 75 | if handle is not None: 76 | self.handles.append(handle) 77 | 78 | # Hooks capturing inputs to the down projection 79 | if Hook.UP in hooks and Hook.UP in self.hooks_available: 80 | handle = self._register_up_hook(i, layer) 81 | if handle is not None: 82 | self.handles.append(handle) 83 | 84 | # Hooks capturing the final MLP output 85 | if Hook.OUT in hooks and Hook.OUT in self.hooks_available: 86 | handle = self._register_out_hook(i, layer) 87 | if handle is not None: 88 | self.handles.append(handle) 89 | 90 | 91 | def remove_hooks(self): 92 | """Remove all registered hooks.""" 93 | for handle in self.handles: 94 | handle.remove() 95 | self.handles = [] 96 | 97 | def clear_captures(self): 98 | """Clear captured activations.""" 99 | self.mlp_activations = { 100 | hook: {} for hook in self.hooks_available 101 | } 102 | 103 | 104 | 105 | LOOKUP_DICT = {} 106 | 107 | def register(*model_types): 108 | ''' 109 | A decorator to record model-specific activation capture implementations. 110 | Arguments 111 | ---------- 112 | *model_types: str 113 | Variable-length list of model types to register the given implementation under. 114 | Returns 115 | ---------- 116 | callable: 117 | Decorator function to register model implementation. 118 | Examples 119 | ---------- 120 | >>> @register('opt') 121 | ... class ActivationCaptureOPT(ActivationCapture): # This class definition gets registered for opt 122 | ''' 123 | def decorator(cls): 124 | for model_type in model_types: 125 | if model_type in LOOKUP_DICT: 126 | raise LookupError(f"{model_type} already present") 127 | LOOKUP_DICT[model_type] = cls 128 | return cls 129 | return decorator 130 | 131 | def capture_model(model): 132 | ''' 133 | Returns an ActivationCapture instance of the correct type for the given model using the 134 | lookup dict (or a default implementation if no model-specific implementation is found). 135 | Arguments 136 | ---------- 137 | model: PreTrainedModel 138 | Model object to use as lookup key and argument to instantiate the activation capture. 139 | Returns 140 | ---------- 141 | ActivationCapture: 142 | ActivationCapture implementation stored in registry for the provided model type 143 | (or default implementation as fallback if none found). 144 | Examples 145 | ---------- 146 | >>> @register('opt') 147 | ... 
class ActivationCaptureOPT(ActivationCapture): # This class definition gets registered for opt 148 | >>> model.activation_capture = capture_model(model) # Loads activation capture class from registry 149 | >>> model.activation_capture.register_hooks() # Use activation capture class to record model activations 150 | __main__.DLRM_Net 151 | ''' 152 | model_type = model.config.model_type 153 | 154 | # Use model-specific implementation if available, otherwise fall back to default 155 | if model_type in LOOKUP_DICT: 156 | act_cls = LOOKUP_DICT[model_type] 157 | else: 158 | act_cls = ActivationCapture 159 | 160 | return act_cls(model) -------------------------------------------------------------------------------- /src/configuration_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | import os 3 | from typing import Union, Any, Type 4 | 5 | 6 | 7 | def build_skip_config(base_config_class: type[PretrainedConfig], model_type_name: str) -> type[PretrainedConfig]: 8 | class SkipConnectionConfig(base_config_class): 9 | model_type: str = model_type_name 10 | has_no_defaults_at_init: bool = True 11 | 12 | def __init__(self, 13 | sparsities: float, 14 | sparsity_method: str = "naive", 15 | predictor_loss_type: str = "bce", 16 | predictor_temperature: float = 1.0, 17 | predictor_loss_alpha: float = 1.0, 18 | predictor_loss_weight: float = 0.1, 19 | use_optimized_weight_cache: bool = True, 20 | capture_activations: str = None, 21 | **kwargs): 22 | self.sparsities = sparsities 23 | self.sparsity_method = sparsity_method 24 | self.predictor_loss_type = predictor_loss_type 25 | self.predictor_temperature = predictor_temperature 26 | self.predictor_loss_alpha = predictor_loss_alpha 27 | self.predictor_loss_weight = predictor_loss_weight 28 | self.use_optimized_weight_cache = use_optimized_weight_cache 29 | self.capture_activations = capture_activations 30 | super().__init__(**kwargs) 31 | 32 | 33 | @classmethod 34 | def from_json_file(cls, json_file: Union[str, os.PathLike]): 35 | """ 36 | Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters. 37 | 38 | Args: 39 | json_file (`str` or `os.PathLike`): 40 | Path to the JSON file containing the parameters. 41 | 42 | Returns: 43 | [`PretrainedConfig`]: The configuration object instantiated from that JSON file. 44 | 45 | """ 46 | config_dict = cls._dict_from_json_file(json_file) 47 | return cls(**config_dict) 48 | 49 | @classmethod 50 | def from_dict(cls, config_dict: dict[str, Any], **kwargs): 51 | if "name_or_path" in kwargs and ("name_or_path" in config_dict or "_name_or_path" in config_dict): 52 | del kwargs["name_or_path"] 53 | return super().from_dict(config_dict, **kwargs) 54 | return SkipConnectionConfig 55 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import llama 2 | from . import qwen2 3 | from . import mistral 4 | from . import phi3 5 | from . import gemma3n 6 | from . import opt 7 | # from . import dia -------------------------------------------------------------------------------- /src/models/gemma3n/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_gemma_skip 2 | from . 
import modelling_gemma_skip 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM 5 | from .configuration_gemma_skip import Gemma3nSkipConnectionConfig 6 | from .modelling_gemma_skip import Gemma3nSkipConnectionForCausalLM 7 | AutoConfig.register("gemma3n-skip", Gemma3nSkipConnectionConfig) 8 | AutoModelForCausalLM.register(Gemma3nSkipConnectionConfig, Gemma3nSkipConnectionForCausalLM) 9 | 10 | __all__ = [configuration_gemma_skip, modelling_gemma_skip] -------------------------------------------------------------------------------- /src/models/gemma3n/configuration_gemma_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import Gemma3nTextConfig 2 | from src.configuration_skip import build_skip_config 3 | 4 | Gemma3nSkipConnectionConfig = build_skip_config(Gemma3nTextConfig, "gemma3n-skip") -------------------------------------------------------------------------------- /src/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_llama_skip 2 | from . import modelling_llama_skip 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM 5 | from .configuration_llama_skip import LlamaSkipConnectionConfig 6 | from .modelling_llama_skip import LlamaSkipConnectionForCausalLM 7 | AutoConfig.register("llama-skip", LlamaSkipConnectionConfig) 8 | AutoModelForCausalLM.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM) 9 | 10 | __all__ = [configuration_llama_skip, modelling_llama_skip] -------------------------------------------------------------------------------- /src/models/llama/configuration_llama_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaConfig 2 | from optimum.utils import NormalizedTextConfig, MistralDummyPastKeyValuesGenerator, DummyTextInputGenerator 3 | from optimum.exporters.onnx.config import TextDecoderWithPositionIdsOnnxConfig 4 | from src.configuration_skip import build_skip_config 5 | 6 | 7 | LlamaSkipConnectionConfig = build_skip_config(LlamaConfig, "llama-skip") 8 | 9 | 10 | class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): 11 | DEFAULT_ONNX_OPSET = 14 # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 12 | 13 | DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) 14 | DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator 15 | NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -------------------------------------------------------------------------------- /src/models/llama/modelling_llama_skip.py: -------------------------------------------------------------------------------- 1 | # limitations under the License. 
2 | from typing import Union 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from transformers.modeling_utils import PreTrainedModel 8 | from transformers.models.llama.modeling_llama import( 9 | LlamaRotaryEmbedding, 10 | LlamaMLP, LlamaAttention 11 | ) 12 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 13 | from transformers.utils import logging 14 | from transformers.cache_utils import Cache 15 | from transformers.modeling_utils import PreTrainedModel 16 | from transformers.utils.import_utils import is_torch_flex_attn_available 17 | 18 | if is_torch_flex_attn_available(): 19 | from torch.nn.attention.flex_attention import BlockMask 20 | 21 | from transformers.integrations.flex_attention import make_flex_block_causal_mask 22 | 23 | 24 | from src.models.llama.configuration_llama_skip import LlamaSkipConnectionConfig 25 | from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | class LlamaSkipRMSNorm(nn.Module): 31 | def __init__(self, hidden_size, eps=1e-5): 32 | """ 33 | LlamaSkipRMSNorm is equivalent to T5LayerNorm 34 | """ 35 | super().__init__() 36 | self.eps = eps 37 | self.weight = nn.Parameter(torch.ones(hidden_size)) 38 | 39 | def _norm(self, x): 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | output = self._norm(x.float()).type_as(x) 44 | return output * self.weight 45 | 46 | def extra_repr(self): 47 | return f"{tuple(self.weight.shape)}, eps={self.eps}" 48 | 49 | class LlamaSkipDecoderLayer(SkipDecoderLayer): 50 | def _init_components(self, config, layer_idx): 51 | self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) 52 | self.input_layernorm = LlamaSkipRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 53 | self.post_attention_layernorm = LlamaSkipRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 54 | 55 | def _set_mlp_train(self, config, layer_idx): 56 | self.mlp = LlamaMLP(config) 57 | 58 | def _set_mlp_inference(self, config, layer_idx): 59 | self.mlp = SkipMLP( 60 | config.hidden_size, 61 | config.intermediate_size, 62 | config.sparsities[layer_idx], 63 | config.mlp_bias, 64 | config.hidden_act, 65 | getattr(config, 'use_weight_cache', True) 66 | ) 67 | 68 | 69 | 70 | class LlamaSkipPreTrainedModel(PreTrainedModel): 71 | config_class = LlamaSkipConnectionConfig 72 | base_model_prefix = "model" 73 | supports_gradient_checkpointing = True 74 | _no_split_modules = ["LlamaSkipDecoderLayer"] 75 | _skip_keys_device_placement = ["past_key_values"] 76 | _supports_flash_attn_2 = True 77 | _supports_sdpa = True 78 | _supports_flex_attn = True 79 | _supports_cache_class = True 80 | _supports_quantized_cache = True 81 | _supports_static_cache = True 82 | _supports_attention_backend = True 83 | 84 | def _init_weights(self, module): 85 | std = self.config.initializer_range 86 | if isinstance(module, nn.Linear): 87 | module.weight.data.normal_(mean=0.0, std=std) 88 | if module.bias is not None: 89 | module.bias.data.zero_() 90 | elif isinstance(module, nn.Embedding): 91 | module.weight.data.normal_(mean=0.0, std=std) 92 | if module.padding_idx is not None: 93 | module.weight.data[module.padding_idx].zero_() 94 | elif isinstance(module, LlamaSkipRMSNorm): 95 | module.weight.data.fill_(1.0) 96 | 97 | 98 | LlamaSkipConnectionModelBase = build_skip_connection_model(LlamaSkipPreTrainedModel) 99 | 100 | class 
LlamaSkipConnectionModel(LlamaSkipConnectionModelBase): 101 | def _init_components(self, config): 102 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # type: ignore 103 | self.layers = nn.ModuleList( 104 | [LlamaSkipDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 105 | ) 106 | self.norm = LlamaSkipRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 107 | self.rotary_emb = LlamaRotaryEmbedding(config=config) 108 | 109 | def _update_causal_mask( 110 | self, 111 | attention_mask: Union[torch.Tensor, "BlockMask"], # type: ignore 112 | input_tensor: torch.Tensor, 113 | cache_position: torch.Tensor, 114 | past_key_values: Cache, 115 | output_attentions: bool = False, 116 | ): 117 | if self.config._attn_implementation == "flash_attention_2": 118 | if attention_mask is not None and (attention_mask == 0.0).any(): 119 | return attention_mask 120 | return None 121 | if self.config._attn_implementation == "flex_attention": 122 | if isinstance(attention_mask, torch.Tensor): 123 | attention_mask = make_flex_block_causal_mask(attention_mask) # type: ignore 124 | return attention_mask 125 | 126 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in 127 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail 128 | # to infer the attention mask. 129 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 130 | using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False 131 | 132 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward 133 | if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: 134 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 135 | attention_mask, 136 | inputs_embeds=input_tensor, 137 | past_key_values_length=past_seen_tokens, 138 | is_training=self.training, 139 | ): 140 | return None 141 | 142 | dtype = input_tensor.dtype 143 | sequence_length = input_tensor.shape[1] 144 | if using_compilable_cache: 145 | target_length = past_key_values.get_max_cache_shape() 146 | else: 147 | target_length = ( 148 | attention_mask.shape[-1] 149 | if isinstance(attention_mask, torch.Tensor) 150 | else past_seen_tokens + sequence_length + 1 151 | ) 152 | 153 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 154 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 155 | attention_mask, 156 | sequence_length=sequence_length, 157 | target_length=target_length, # type: ignore 158 | dtype=dtype, 159 | cache_position=cache_position, 160 | batch_size=input_tensor.shape[0], 161 | ) 162 | 163 | if ( 164 | self.config._attn_implementation == "sdpa" 165 | and attention_mask is not None 166 | and attention_mask.device.type in ["cuda", "xpu", "npu"] 167 | and not output_attentions 168 | ): 169 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 170 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
171 | # Details: https://github.com/pytorch/pytorch/issues/110213 172 | min_dtype = torch.finfo(dtype).min 173 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # type: ignore 174 | 175 | return causal_mask 176 | 177 | @staticmethod 178 | def _prepare_4d_causal_attention_mask_with_cache_position( 179 | attention_mask: torch.Tensor, 180 | sequence_length: int, 181 | target_length: int, 182 | dtype: torch.dtype, 183 | cache_position: torch.Tensor, 184 | batch_size: int, 185 | **kwargs, 186 | ): 187 | """ 188 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 189 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 190 | 191 | Args: 192 | attention_mask (`torch.Tensor`): 193 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape 194 | `(batch_size, 1, query_length, key_value_length)`. 195 | sequence_length (`int`): 196 | The sequence length being processed. 197 | target_length (`int`): 198 | The target length: when generating with static cache, the mask should be as long as the static cache, 199 | to account for the 0 padding, the part of the cache that is not filled yet. 200 | dtype (`torch.dtype`): 201 | The dtype to use for the 4D attention mask. 202 | cache_position (`torch.Tensor`): 203 | Indices depicting the position of the input sequence tokens in the sequence. 204 | batch_size (`torch.Tensor`): 205 | Batch size. 206 | """ 207 | if attention_mask is not None and attention_mask.dim() == 4: 208 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 209 | causal_mask = attention_mask 210 | else: 211 | min_dtype = torch.finfo(dtype).min 212 | causal_mask = torch.full( 213 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device 214 | ) 215 | if sequence_length != 1: 216 | causal_mask = torch.triu(causal_mask, diagonal=1) 217 | causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) 218 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) 219 | if attention_mask is not None: 220 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit 221 | mask_length = attention_mask.shape[-1] 222 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( 223 | causal_mask.device 224 | ) 225 | padding_mask = padding_mask == 0 226 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( 227 | padding_mask, min_dtype 228 | ) 229 | 230 | return causal_mask 231 | 232 | LlamaSkipConnectionForCausalLM = build_skip_connection_model_for_causal_lm(LlamaSkipPreTrainedModel, LlamaSkipConnectionModel) 233 | -------------------------------------------------------------------------------- /src/models/mistral/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_mistral_skip 2 | from . 
import modelling_mistral_skip 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM 5 | from .configuration_mistral_skip import MistralSkipConnectionConfig 6 | from .modelling_mistral_skip import MistralSkipConnectionForCausalLM 7 | AutoConfig.register("mistral-skip", MistralSkipConnectionConfig) 8 | AutoModelForCausalLM.register(MistralSkipConnectionConfig, MistralSkipConnectionForCausalLM) 9 | 10 | __all__ = [configuration_mistral_skip, modelling_mistral_skip] -------------------------------------------------------------------------------- /src/models/mistral/configuration_mistral_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import MistralConfig 2 | from src.configuration_skip import build_skip_config 3 | 4 | MistralSkipConnectionConfig = build_skip_config(MistralConfig, "mistral-skip") -------------------------------------------------------------------------------- /src/models/mistral/modelling_mistral_skip.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 7 | from transformers.utils import logging 8 | from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache 9 | from transformers.modeling_utils import PreTrainedModel 10 | 11 | 12 | from transformers.models.mistral.modeling_mistral import( 13 | MistralMLP, MistralAttention, MistralRMSNorm, MistralRotaryEmbedding, 14 | ) 15 | 16 | from transformers.utils.import_utils import is_torch_flex_attn_available 17 | 18 | if is_torch_flex_attn_available(): 19 | from torch.nn.attention.flex_attention import BlockMask 20 | 21 | from transformers.integrations.flex_attention import make_flex_block_causal_mask 22 | 23 | 24 | from src.models.mistral.configuration_mistral_skip import MistralSkipConnectionConfig 25 | from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class MistralSkipDecoderLayer(SkipDecoderLayer): 31 | def _init_components(self, config, layer_idx): 32 | self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) 33 | self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 34 | self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 35 | 36 | def _set_mlp_train(self, config, layer_idx): 37 | self.mlp = MistralMLP(config) 38 | 39 | def _set_mlp_inference(self, config, layer_idx): 40 | self.mlp = SkipMLP( 41 | config.hidden_size, 42 | config.intermediate_size, 43 | config.sparsity, 44 | False, 45 | "silu" 46 | ) 47 | 48 | class MistralSkipPreTrainedModel(PreTrainedModel): 49 | config_class = MistralSkipConnectionConfig 50 | base_model_prefix = "model" 51 | supports_gradient_checkpointing = True 52 | _no_split_modules = ["MistralSkipDecoderLayer"] 53 | _skip_keys_device_placement = ["past_key_values"] 54 | _supports_flash_attn_2 = True 55 | _supports_sdpa = True 56 | _supports_flex_attn = True 57 | _supports_cache_class = True 58 | _supports_quantized_cache = True 59 | _supports_static_cache = True 60 | _supports_attention_backend = True 61 | 62 | def _init_weights(self, module): 63 | std = self.config.initializer_range 64 | if isinstance(module, nn.Linear): 65 | module.weight.data.normal_(mean=0.0, std=std) 66 | if module.bias is not None: 67 | 
module.bias.data.zero_() 68 | elif isinstance(module, nn.Embedding): 69 | module.weight.data.normal_(mean=0.0, std=std) 70 | if module.padding_idx is not None: 71 | module.weight.data[module.padding_idx].zero_() 72 | elif isinstance(module, MistralRMSNorm): 73 | module.weight.data.fill_(1.0) 74 | 75 | 76 | MistralSkipConnectionModelBase = build_skip_connection_model(MistralSkipPreTrainedModel) 77 | 78 | class MistralSkipConnectionModel(MistralSkipConnectionModelBase): 79 | def _init_components(self, config): 80 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # type: ignore 81 | self.layers = nn.ModuleList( 82 | [MistralSkipDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 83 | ) 84 | self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 85 | self.rotary_emb = MistralRotaryEmbedding(config=config) 86 | 87 | def _update_causal_mask( 88 | self, 89 | attention_mask: Union[torch.Tensor, "BlockMask"], 90 | input_tensor: torch.Tensor, 91 | cache_position: torch.Tensor, 92 | past_key_values: Cache, 93 | output_attentions: bool = False, 94 | ): 95 | if self.config._attn_implementation == "flash_attention_2": 96 | if attention_mask is not None and past_key_values is not None: 97 | is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] 98 | if is_padding_right: 99 | raise ValueError( 100 | "You are attempting to perform batched generation with padding_side='right'" 101 | " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " 102 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 103 | ) 104 | if attention_mask is not None and 0.0 in attention_mask: 105 | return attention_mask 106 | return None 107 | if self.config._attn_implementation == "flex_attention": 108 | if isinstance(attention_mask, torch.Tensor): 109 | attention_mask = make_flex_block_causal_mask(attention_mask) 110 | return attention_mask 111 | 112 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in 113 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail 114 | # to infer the attention mask. 
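        # Returning None from this point on lets SDPA build the causal structure itself via
        # `is_causal=True`, so no (batch, 1, query_len, kv_len) mask tensor has to be materialized
        # for the common unpadded case.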
115 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 116 | using_static_cache = isinstance(past_key_values, StaticCache) 117 | using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) 118 | 119 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward 120 | if ( 121 | self.config._attn_implementation == "sdpa" 122 | and not (using_static_cache or using_sliding_window_cache) 123 | and not output_attentions 124 | ): 125 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 126 | attention_mask, 127 | inputs_embeds=input_tensor, 128 | past_key_values_length=past_seen_tokens, 129 | sliding_window=self.config.sliding_window, 130 | is_training=self.training, 131 | ): 132 | return None 133 | 134 | dtype = input_tensor.dtype 135 | min_dtype = torch.finfo(dtype).min 136 | sequence_length = input_tensor.shape[1] 137 | # SlidingWindowCache or StaticCache 138 | if using_sliding_window_cache or using_static_cache: 139 | target_length = past_key_values.get_max_cache_shape() 140 | # DynamicCache or no cache 141 | else: 142 | target_length = ( 143 | attention_mask.shape[-1] 144 | if isinstance(attention_mask, torch.Tensor) 145 | else past_seen_tokens + sequence_length + 1 146 | ) 147 | 148 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 149 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 150 | attention_mask, 151 | sequence_length=sequence_length, 152 | target_length=target_length, 153 | dtype=dtype, 154 | cache_position=cache_position, 155 | batch_size=input_tensor.shape[0], 156 | config=self.config, 157 | past_key_values=past_key_values, 158 | ) 159 | 160 | if ( 161 | self.config._attn_implementation == "sdpa" 162 | and attention_mask is not None 163 | and attention_mask.device.type in ["cuda", "xpu", "npu"] 164 | and not output_attentions 165 | ): 166 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 167 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 168 | # Details: https://github.com/pytorch/pytorch/issues/110213 169 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) 170 | 171 | return causal_mask 172 | 173 | @staticmethod 174 | def _prepare_4d_causal_attention_mask_with_cache_position( 175 | attention_mask: torch.Tensor, 176 | sequence_length: int, 177 | target_length: int, 178 | dtype: torch.dtype, 179 | cache_position: torch.Tensor, 180 | batch_size: int, 181 | config: MistralSkipConnectionConfig, 182 | past_key_values: Cache, 183 | ): 184 | """ 185 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 186 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 187 | 188 | Args: 189 | attention_mask (`torch.Tensor`): 190 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. 191 | sequence_length (`int`): 192 | The sequence length being processed. 193 | target_length (`int`): 194 | The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. 195 | dtype (`torch.dtype`): 196 | The dtype to use for the 4D attention mask. 
197 | cache_position (`torch.Tensor`): 198 | Indices depicting the position of the input sequence tokens in the sequence. 199 | batch_size (`torch.Tensor`): 200 | Batch size. 201 | config (`MistralConfig`): 202 | The model's configuration class 203 | past_key_values (`Cache`): 204 | The cache class that is being used currently to generate 205 | """ 206 | if attention_mask is not None and attention_mask.dim() == 4: 207 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 208 | causal_mask = attention_mask 209 | else: 210 | min_dtype = torch.finfo(dtype).min 211 | causal_mask = torch.full( 212 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device 213 | ) 214 | diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( 215 | -1, 1 216 | ) 217 | text_config = config.get_text_config() 218 | if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: 219 | # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also 220 | # the check is needed to verify is current checkpoint was trained with sliding window or not 221 | if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: 222 | sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( 223 | cache_position.reshape(-1, 1) - text_config.sliding_window 224 | ) 225 | diagonal_attend_mask.bitwise_or_(sliding_attend_mask) 226 | causal_mask *= diagonal_attend_mask 227 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) 228 | if attention_mask is not None: 229 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit 230 | if attention_mask.shape[-1] > target_length: 231 | attention_mask = attention_mask[:, :target_length] 232 | mask_length = attention_mask.shape[-1] 233 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( 234 | causal_mask.device 235 | ) 236 | padding_mask = padding_mask == 0 237 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( 238 | padding_mask, min_dtype 239 | ) 240 | return causal_mask 241 | 242 | MistralSkipConnectionForCausalLM = build_skip_connection_model_for_causal_lm(MistralSkipPreTrainedModel, MistralSkipConnectionModel) -------------------------------------------------------------------------------- /src/models/opt/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_opt_skip 2 | from . import modelling_opt_skip 3 | from . 
import activation_capture_opt 4 | 5 | from transformers import AutoConfig, AutoModelForCausalLM 6 | from .configuration_opt_skip import OPTSkipConnectionConfig 7 | from .modelling_opt_skip import OPTSkipConnectionForCausalLM 8 | AutoConfig.register("opt-skip", OPTSkipConnectionConfig) 9 | AutoModelForCausalLM.register(OPTSkipConnectionConfig, OPTSkipConnectionForCausalLM) 10 | 11 | __all__ = [configuration_opt_skip, modelling_opt_skip] -------------------------------------------------------------------------------- /src/models/opt/activation_capture_opt.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from src.activation_capture import ActivationCapture, Hook, register 4 | 5 | @register('opt') 6 | class ActivationCaptureOpt(ActivationCapture): 7 | hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.OUT] 8 | 9 | def _register_act_hook(self, layer_idx, layer): 10 | def hook(module, input, output): 11 | # Just detach, don't clone or move to CPU yet 12 | self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach() 13 | return output 14 | handle = layer.activation_fn.register_forward_hook(hook) 15 | return handle 16 | 17 | def _register_out_hook(self, layer_idx, layer): 18 | def hook(module, input, output): 19 | # Just detach, don't clone or move to CPU yet 20 | self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach() 21 | return output 22 | handle = layer.fc2.register_forward_hook(hook) 23 | return handle 24 | 25 | def get_layers(self): 26 | return self.model.model.decoder.layers -------------------------------------------------------------------------------- /src/models/opt/configuration_opt_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import OPTConfig, PretrainedConfig 2 | import os 3 | from typing import Union, Any 4 | from src.configuration_skip import build_skip_config 5 | 6 | OPTSkipConnectionConfig: type[OPTConfig] = build_skip_config(OPTConfig, "opt-skip") -------------------------------------------------------------------------------- /src/models/phi3/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_phi_skip 2 | from . import modelling_phi_skip 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM 5 | from .configuration_phi_skip import Phi3SkipConnectionConfig 6 | from .modelling_phi_skip import Phi3SkipConnectionForCausalLM 7 | AutoConfig.register("phi3-skip", Phi3SkipConnectionConfig) 8 | AutoModelForCausalLM.register(Phi3SkipConnectionConfig, Phi3SkipConnectionForCausalLM) 9 | 10 | __all__ = [configuration_phi_skip, modelling_phi_skip] -------------------------------------------------------------------------------- /src/models/phi3/configuration_phi_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import Phi3Config, PretrainedConfig 2 | import os 3 | from typing import Union, Any 4 | from src.configuration_skip import build_skip_config 5 | 6 | Phi3SkipConnectionConfig: type[Phi3Config] = build_skip_config(Phi3Config, "phi3-skip") 7 | -------------------------------------------------------------------------------- /src/models/qwen2/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_qwen_skip 2 | from . 
import modelling_qwen_skip 3 | 4 | from transformers import AutoConfig, AutoModelForCausalLM 5 | from .configuration_qwen_skip import Qwen2SkipConnectionConfig 6 | from .modelling_qwen_skip import Qwen2SkipConnectionForCausalLM 7 | AutoConfig.register("qwen2-skip", Qwen2SkipConnectionConfig) 8 | AutoModelForCausalLM.register(Qwen2SkipConnectionConfig, Qwen2SkipConnectionForCausalLM) 9 | 10 | __all__ = [configuration_qwen_skip, modelling_qwen_skip] -------------------------------------------------------------------------------- /src/models/qwen2/configuration_qwen_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import Qwen2Config 2 | from src.configuration_skip import build_skip_config 3 | 4 | Qwen2SkipConnectionConfig = build_skip_config(Qwen2Config, "qwen2-skip") 5 | -------------------------------------------------------------------------------- /src/models/qwen2/modelling_qwen_skip.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from transformers.modeling_attn_mask_utils import AttentionMaskConverter 7 | from transformers.utils import logging 8 | from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache 9 | from transformers.modeling_utils import PreTrainedModel 10 | from transformers.utils.import_utils import is_torch_flex_attn_available 11 | 12 | if is_torch_flex_attn_available(): 13 | from torch.nn.attention.flex_attention import BlockMask 14 | 15 | from transformers.integrations.flex_attention import make_flex_block_causal_mask 16 | 17 | from transformers.models.qwen2.modeling_qwen2 import( 18 | Qwen2MLP, Qwen2Attention, Qwen2RMSNorm, Qwen2RotaryEmbedding, 19 | ) 20 | 21 | 22 | from src.models.qwen2.configuration_qwen_skip import Qwen2SkipConnectionConfig 23 | from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | 28 | class Qwen2SkipDecoderLayer(SkipDecoderLayer): 29 | def _init_components(self, config, layer_idx): 30 | self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) 31 | self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 32 | self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 33 | if config.use_sliding_window and config._attn_implementation != "flash_attention_2": 34 | logger.warning( 35 | f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " 36 | "unexpected results may be encountered." 
37 | ) 38 | 39 | def _set_mlp_train(self, config, layer_idx): 40 | self.mlp = Qwen2MLP(config) 41 | 42 | def _set_mlp_inference(self, config, layer_idx): 43 | self.mlp = SkipMLP( 44 | config.hidden_size, 45 | config.intermediate_size, 46 | config.sparsity, 47 | False, 48 | "silu" 49 | ) 50 | 51 | class Qwen2SkipPreTrainedModel(PreTrainedModel): 52 | config_class = Qwen2SkipConnectionConfig 53 | base_model_prefix = "model" 54 | supports_gradient_checkpointing = True 55 | _no_split_modules = ["Qwen2SkipDecoderLayer"] 56 | _skip_keys_device_placement = ["past_key_values"] 57 | _supports_flash_attn_2 = True 58 | _supports_sdpa = True 59 | _supports_flex_attn = True 60 | _supports_cache_class = True 61 | _supports_quantized_cache = True 62 | _supports_static_cache = True 63 | _supports_attention_backend = True 64 | 65 | def _init_weights(self, module): 66 | std = self.config.initializer_range 67 | if isinstance(module, nn.Linear): 68 | module.weight.data.normal_(mean=0.0, std=std) 69 | if module.bias is not None: 70 | module.bias.data.zero_() 71 | elif isinstance(module, nn.Embedding): 72 | module.weight.data.normal_(mean=0.0, std=std) 73 | if module.padding_idx is not None: 74 | module.weight.data[module.padding_idx].zero_() 75 | elif isinstance(module, Qwen2RMSNorm): 76 | module.weight.data.fill_(1.0) 77 | 78 | 79 | Qwen2SkipConnectionModelBase = build_skip_connection_model(Qwen2SkipPreTrainedModel) 80 | 81 | class Qwen2SkipConnectionModel(Qwen2SkipConnectionModelBase): 82 | def _init_components(self, config): 83 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) # type: ignore 84 | self.layers = nn.ModuleList( 85 | [Qwen2SkipDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 86 | ) 87 | self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 88 | self.rotary_emb = Qwen2RotaryEmbedding(config=config) 89 | 90 | def _update_causal_mask( 91 | self, 92 | attention_mask: Union[torch.Tensor, "BlockMask"], 93 | input_tensor: torch.Tensor, 94 | cache_position: torch.Tensor, 95 | past_key_values: Cache, 96 | output_attentions: bool = False, 97 | ): 98 | if self.config._attn_implementation == "flash_attention_2": 99 | if attention_mask is not None and past_key_values is not None: 100 | is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] 101 | if is_padding_right: 102 | raise ValueError( 103 | "You are attempting to perform batched generation with padding_side='right'" 104 | " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " 105 | " call `tokenizer.padding_side = 'left'` before tokenizing the input. " 106 | ) 107 | if attention_mask is not None and 0.0 in attention_mask: 108 | return attention_mask 109 | return None 110 | if self.config._attn_implementation == "flex_attention": 111 | if isinstance(attention_mask, torch.Tensor): 112 | attention_mask = make_flex_block_causal_mask(attention_mask) 113 | return attention_mask 114 | 115 | # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in 116 | # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail 117 | # to infer the attention mask. 
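        # Static and sliding-window caches never take the early-return shortcut below: their masks
        # must be materialized at the full pre-allocated cache length (see `get_max_cache_shape()`),
        # which the `is_causal` fast path cannot express.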
118 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 119 | using_static_cache = isinstance(past_key_values, StaticCache) 120 | using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) 121 | 122 | # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward 123 | if ( 124 | self.config._attn_implementation == "sdpa" 125 | and not (using_static_cache or using_sliding_window_cache) 126 | and not output_attentions 127 | ): 128 | if AttentionMaskConverter._ignore_causal_mask_sdpa( 129 | attention_mask, 130 | inputs_embeds=input_tensor, 131 | past_key_values_length=past_seen_tokens, 132 | sliding_window=self.config.sliding_window, 133 | is_training=self.training, 134 | ): 135 | return None 136 | 137 | dtype = input_tensor.dtype 138 | min_dtype = torch.finfo(dtype).min 139 | sequence_length = input_tensor.shape[1] 140 | # SlidingWindowCache or StaticCache 141 | if using_sliding_window_cache or using_static_cache: 142 | target_length = past_key_values.get_max_cache_shape() 143 | # DynamicCache or no cache 144 | else: 145 | target_length = ( 146 | attention_mask.shape[-1] 147 | if isinstance(attention_mask, torch.Tensor) 148 | else past_seen_tokens + sequence_length + 1 149 | ) 150 | 151 | # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 152 | causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( 153 | attention_mask, 154 | sequence_length=sequence_length, 155 | target_length=target_length, 156 | dtype=dtype, 157 | cache_position=cache_position, 158 | batch_size=input_tensor.shape[0], 159 | config=self.config, 160 | past_key_values=past_key_values, 161 | ) 162 | 163 | if ( 164 | self.config._attn_implementation == "sdpa" 165 | and attention_mask is not None 166 | and attention_mask.device.type in ["cuda", "xpu", "npu"] 167 | and not output_attentions 168 | ): 169 | # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when 170 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 171 | # Details: https://github.com/pytorch/pytorch/issues/110213 172 | causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) 173 | 174 | return causal_mask 175 | 176 | @staticmethod 177 | def _prepare_4d_causal_attention_mask_with_cache_position( 178 | attention_mask: torch.Tensor, 179 | sequence_length: int, 180 | target_length: int, 181 | dtype: torch.dtype, 182 | cache_position: torch.Tensor, 183 | batch_size: int, 184 | config, 185 | past_key_values: Cache, 186 | ): 187 | """ 188 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 189 | `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 190 | 191 | Args: 192 | attention_mask (`torch.Tensor`): 193 | A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. 194 | sequence_length (`int`): 195 | The sequence length being processed. 196 | target_length (`int`): 197 | The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. 198 | dtype (`torch.dtype`): 199 | The dtype to use for the 4D attention mask. 
200 | cache_position (`torch.Tensor`): 201 | Indices depicting the position of the input sequence tokens in the sequence. 202 | batch_size (`torch.Tensor`): 203 | Batch size. 204 | config (`Qwen2Config`): 205 | The model's configuration class 206 | past_key_values (`Cache`): 207 | The cache class that is being used currently to generate 208 | """ 209 | if attention_mask is not None and attention_mask.dim() == 4: 210 | # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 211 | causal_mask = attention_mask 212 | else: 213 | min_dtype = torch.finfo(dtype).min 214 | causal_mask = torch.full( 215 | (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device 216 | ) 217 | diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( 218 | -1, 1 219 | ) 220 | text_config = config.get_text_config() 221 | if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: 222 | # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also 223 | # the check is needed to verify is current checkpoint was trained with sliding window or not 224 | if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: 225 | sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( 226 | cache_position.reshape(-1, 1) - text_config.sliding_window 227 | ) 228 | diagonal_attend_mask.bitwise_or_(sliding_attend_mask) 229 | causal_mask *= diagonal_attend_mask 230 | causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) 231 | if attention_mask is not None: 232 | causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit 233 | if attention_mask.shape[-1] > target_length: 234 | attention_mask = attention_mask[:, :target_length] 235 | mask_length = attention_mask.shape[-1] 236 | padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( 237 | causal_mask.device 238 | ) 239 | padding_mask = padding_mask == 0 240 | causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( 241 | padding_mask, min_dtype 242 | ) 243 | return causal_mask 244 | 245 | Qwen2SkipConnectionForCausalLM = build_skip_connection_model_for_causal_lm(Qwen2SkipPreTrainedModel, Qwen2SkipConnectionModel) -------------------------------------------------------------------------------- /src/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cuda_utils 2 | from . import logger 3 | from . import random 4 | from . import registry 5 | from . import saver 6 | from . 
import sys_utils -------------------------------------------------------------------------------- /src/utilities/cuda_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import socket 3 | import logging 4 | import numpy as np 5 | import time 6 | import threading 7 | from typing import Dict, List, Optional 8 | 9 | 10 | class GPUMonitor: 11 | """Monitor GPU usage during inference.""" 12 | 13 | def __init__(self, monitoring_interval: float = 0.1): 14 | self.monitoring_interval = monitoring_interval 15 | self._gpu_memory_usage = [] 16 | self._gpu_utilization = [] 17 | self._is_monitoring = False 18 | self._monitor_thread = None 19 | 20 | def _monitor_gpu(self): 21 | """Background monitoring of GPU metrics.""" 22 | try: 23 | import pynvml 24 | pynvml.nvmlInit() 25 | handle = pynvml.nvmlDeviceGetHandleByIndex(0) 26 | 27 | while self._is_monitoring: 28 | # Get memory info 29 | memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) 30 | memory_used_mb = memory_info.used / 1024**2 31 | 32 | # Get utilization 33 | utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) 34 | gpu_util = utilization.gpu 35 | 36 | self._gpu_memory_usage.append(memory_used_mb) 37 | self._gpu_utilization.append(gpu_util) 38 | 39 | time.sleep(self.monitoring_interval) 40 | 41 | except ImportError: 42 | # Fallback to torch methods if pynvml not available 43 | while self._is_monitoring: 44 | if torch.cuda.is_available(): 45 | memory_used_mb = torch.cuda.memory_allocated() / 1024**2 46 | self._gpu_memory_usage.append(memory_used_mb) 47 | time.sleep(self.monitoring_interval) 48 | 49 | def start(self): 50 | """Start GPU monitoring.""" 51 | self._is_monitoring = True 52 | self._gpu_memory_usage.clear() 53 | self._gpu_utilization.clear() 54 | self._monitor_thread = threading.Thread(target=self._monitor_gpu) 55 | self._monitor_thread.start() 56 | 57 | def stop(self): 58 | """Stop GPU monitoring.""" 59 | self._is_monitoring = False 60 | if self._monitor_thread: 61 | self._monitor_thread.join() 62 | 63 | def get_peak_usage(self) -> Dict: 64 | """Get peak GPU usage metrics.""" 65 | if not self._gpu_memory_usage: 66 | return {"peak_gpu_memory_mb": 0, "p90_gpu_utilization": 0} 67 | 68 | return { 69 | "peak_gpu_memory_mb": max(self._gpu_memory_usage), 70 | "p90_gpu_memory_mb": np.percentile(self._gpu_memory_usage, 90), 71 | "max_gpu_utilization": max(self._gpu_utilization) if self._gpu_utilization else 0, 72 | "p90_gpu_utilization": np.percentile(self._gpu_utilization, 90) if self._gpu_utilization else 0 73 | } 74 | 75 | 76 | def map_to_cuda(args, device=None, **kwargs): 77 | if isinstance(args, (list, tuple)): 78 | return [map_to_cuda(arg, device, **kwargs) for arg in args] 79 | elif isinstance(args, dict): 80 | return {k: map_to_cuda(v, device, **kwargs) for k, v in args.items()} 81 | elif isinstance(args, torch.Tensor): 82 | return args.cuda(device, **kwargs) 83 | else: 84 | raise TypeError("unsupported type for cuda migration") 85 | 86 | 87 | def map_to_list(model_params): 88 | for k in model_params.keys(): 89 | model_params[k] = model_params[k].detach().numpy().tolist() 90 | return model_params 91 | 92 | 93 | def mapping_processes_to_gpus(gpu_config, process_id, worker_number): 94 | if gpu_config == None: 95 | device = torch.device("cpu") 96 | logging.info(device) 97 | # return gpu_util_map[process_id][1] 98 | return device 99 | else: 100 | logging.info(gpu_config) 101 | gpu_util_map = {} 102 | i = 0 103 | for host, gpus_util_map_host in gpu_config.items(): 104 | for 
gpu_j, num_process_on_gpu in enumerate(gpus_util_map_host): 105 | for _ in range(num_process_on_gpu): 106 | gpu_util_map[i] = (host, gpu_j) 107 | i += 1 108 | logging.info("Process: %d" % (process_id)) 109 | logging.info("host: %s" % (gpu_util_map[process_id][0])) 110 | logging.info("gethostname: %s" % (socket.gethostname())) 111 | logging.info("gpu: %d" % (gpu_util_map[process_id][1])) 112 | assert i == worker_number 113 | 114 | device = torch.device( 115 | "cuda:" + str(gpu_util_map[process_id][1]) 116 | if torch.cuda.is_available() else "cpu") 117 | logging.info(device) 118 | # return gpu_util_map[process_id][1] 119 | return device 120 | 121 | 122 | def initialize_cuda_safely() -> bool: 123 | """Initialize CUDA context safely, handling common issues.""" 124 | if not torch.cuda.is_available(): 125 | return False 126 | 127 | try: 128 | # Try to initialize CUDA context 129 | torch.cuda.init() 130 | torch.cuda.empty_cache() 131 | 132 | # Test basic CUDA operations 133 | device_count = torch.cuda.device_count() 134 | print(f"CUDA initialized successfully. Found {device_count} GPU(s).") 135 | 136 | # Test memory allocation on each device 137 | for i in range(device_count): 138 | try: 139 | torch.cuda.set_device(i) 140 | # Try a small memory allocation 141 | test_tensor = torch.randn(10, device=f'cuda:{i}') 142 | del test_tensor 143 | torch.cuda.empty_cache() 144 | print(f"GPU {i} is accessible and functional.") 145 | except Exception as e: 146 | print(f"Warning: GPU {i} has issues: {e}") 147 | 148 | return True 149 | 150 | except Exception as e: 151 | print(f"CUDA initialization failed: {e}") 152 | print("Falling back to CPU mode.") 153 | return False 154 | 155 | 156 | 157 | def get_gpu_info() -> Optional[List[Dict]]: 158 | """Get GPU information if CUDA is available.""" 159 | if not torch.cuda.is_available(): 160 | return None 161 | 162 | gpu_info = [] 163 | try: 164 | # Clear any existing CUDA context issues 165 | torch.cuda.empty_cache() 166 | 167 | for i in range(torch.cuda.device_count()): 168 | try: 169 | # Set the device to ensure proper context 170 | torch.cuda.set_device(i) 171 | props = torch.cuda.get_device_properties(i) 172 | 173 | # Try to get memory info with retries 174 | mem_info = None 175 | for retry in range(3): 176 | try: 177 | mem_info = torch.cuda.mem_get_info(i) 178 | break 179 | except RuntimeError as e: 180 | if retry == 2: # Last retry 181 | print(f"Warning: Could not get memory info for GPU {i}: {e}") 182 | # Use default values if we can't get memory info 183 | mem_info = (0, props.total_memory) 184 | else: 185 | # Wait a bit and clear cache before retry 186 | time.sleep(0.1) 187 | torch.cuda.empty_cache() 188 | 189 | if mem_info: 190 | free_memory = mem_info[0] / 1024**3 # Convert to GB 191 | total_memory = mem_info[1] / 1024**3 # Convert to GB 192 | else: 193 | # Fallback to device properties 194 | free_memory = props.total_memory / 1024**3 195 | total_memory = props.total_memory / 1024**3 196 | 197 | gpu_info.append({ 198 | 'name': props.name, 199 | 'compute_capability': f"{props.major}.{props.minor}", 200 | 'total_memory': f"{total_memory:.2f}GB", 201 | 'free_memory': f"{free_memory:.2f}GB", 202 | 'multi_processor_count': props.multi_processor_count 203 | }) 204 | 205 | except RuntimeError as e: 206 | print(f"Warning: Could not get info for GPU {i}: {e}") 207 | # Add a placeholder entry 208 | gpu_info.append({ 209 | 'name': f"GPU {i} (Error accessing device)", 210 | 'compute_capability': "Unknown", 211 | 'total_memory': "Unknown", 212 | 'free_memory': "Unknown", 
213 | 'multi_processor_count': "Unknown" 214 | }) 215 | 216 | except Exception as e: 217 | print(f"Warning: Could not enumerate GPUs: {e}") 218 | return None 219 | 220 | return gpu_info if gpu_info else None 221 | 222 | def setup_cuda_debugging(verbose: bool = False): 223 | """Setup CUDA debugging flags.""" 224 | try: 225 | # Always set CUDA_LAUNCH_BLOCKING for better error reporting 226 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 227 | 228 | # Clear any existing CUDA context issues 229 | if torch.cuda.is_available(): 230 | torch.cuda.empty_cache() 231 | 232 | # Try to set primary device safely 233 | try: 234 | torch.cuda.set_device(0) 235 | except RuntimeError as e: 236 | print(f"Warning: Could not set CUDA device 0: {e}") 237 | 238 | if verbose: 239 | try: 240 | # Enable CUDA memory stats with error handling 241 | torch.cuda.memory.set_per_process_memory_fraction(0.9) 242 | torch.cuda.memory._record_memory_history(max_entries=10000) 243 | except Exception as e: 244 | print(f"Warning: Could not setup CUDA memory debugging: {e}") 245 | except Exception as e: 246 | print(f"Warning: CUDA debugging setup failed: {e}") -------------------------------------------------------------------------------- /src/utilities/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from time import time 4 | 5 | 6 | class BaseLogger(ABC): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | 10 | @staticmethod 11 | def time(func): 12 | def decorated(*args, **kwargs): 13 | start_time = time() 14 | out = func(*args, **kwargs) 15 | end_time = time() 16 | logging.info("aggregate time cost: %d" % (end_time - start_time)) 17 | return out 18 | 19 | return decorated 20 | 21 | @abstractmethod 22 | def log(*args, **kwargs): 23 | pass 24 | 25 | @abstractmethod 26 | def log_gradients(*args, **kwargs): 27 | pass 28 | 29 | @abstractmethod 30 | def add_scalar(*args, **kwargs): 31 | pass 32 | 33 | @abstractmethod 34 | def add_histogram(*args, **kwargs): 35 | pass 36 | 37 | @abstractmethod 38 | def add_graph(*args, **kwargs): 39 | pass 40 | 41 | 42 | try: 43 | from torch.utils.tensorboard import SummaryWriter 44 | 45 | class TBLogger(SummaryWriter, BaseLogger): 46 | def __init__(self, log_dir, comment="", max_queue=10): 47 | super().__init__(log_dir=log_dir, 48 | comment=comment, 49 | max_queue=max_queue) 50 | 51 | def log(self, *args, **kwargs): 52 | print(*args, **kwargs) 53 | 54 | def log_gradients(self, model, step, to_normalize=True): 55 | for name, param in model.named_parameters(): 56 | if to_normalize: 57 | grad = param.grad.norm() 58 | self.add_scalar("grads/"+name, grad, global_step=step) 59 | else: 60 | grad = param.grad 61 | self.add_histogram("grads/"+name, grad, global_step=step) 62 | 63 | except ImportError: 64 | UserWarning("Tensorboard not installed. No Tensorboard logging.") 65 | 66 | try: 67 | import neptune 68 | 69 | class NeptuneLogger(BaseLogger): 70 | def __init__(self, log_dir, comment="", max_queue=10): 71 | super().__init__() 72 | 73 | def log(self, *args, **kwargs): 74 | print(*args, **kwargs) 75 | 76 | def log_gradients(self, model, step, to_normalize=True): 77 | for name, param in model.named_parameters(): 78 | if to_normalize: 79 | grad = param.grad.norm() 80 | self.add_scalar("grads/"+name, grad, global_step=step) 81 | else: 82 | grad = param.grad 83 | self.add_histogram("grads/"+name, grad, global_step=step) 84 | except ImportError: 85 | UserWarning("Neptune not installed. 
No Neptune logging.") 86 | 87 | 88 | class NoOpLogger(BaseLogger): 89 | def __init__(self) -> None: 90 | super().__init__() 91 | 92 | def log(*args, **kwargs): 93 | pass 94 | 95 | def log_gradients(*args, **kwargs): 96 | pass 97 | 98 | def add_scalar(*args, **kwargs): 99 | pass 100 | 101 | def add_histogram(*args, **kwargs): 102 | pass 103 | 104 | def add_graph(*args, **kwargs): 105 | pass 106 | -------------------------------------------------------------------------------- /src/utilities/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import attr 4 | from typing import Dict 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class RandomState: 10 | def __init__(self): 11 | self.random_mod_state = random.getstate() 12 | self.np_state = np.random.get_state() 13 | self.torch_cpu_state = torch.get_rng_state() 14 | self.torch_gpu_states = [ 15 | torch.cuda.get_rng_state(d) 16 | for d in range(torch.cuda.device_count()) 17 | ] 18 | 19 | def restore(self): 20 | random.setstate(self.random_mod_state) 21 | np.random.set_state(self.np_state) 22 | torch.set_rng_state(self.torch_cpu_state) 23 | for d, state in enumerate(self.torch_gpu_states): 24 | torch.cuda.set_rng_state(state, d) 25 | 26 | 27 | class RandomContext: 28 | '''Save and restore state of PyTorch, NumPy, Python RNGs.''' 29 | 30 | def __init__(self, seed=None): 31 | outside_state = RandomState() 32 | 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | if seed is None: 36 | torch.manual_seed(random.randint(-sys.maxsize - 1, sys.maxsize)) 37 | else: 38 | torch.manual_seed(seed) 39 | # torch.cuda.manual_seed_all is called by torch.manual_seed 40 | self.inside_state = RandomState() 41 | 42 | outside_state.restore() 43 | 44 | self._active = False 45 | 46 | def __enter__(self): 47 | if self._active: 48 | raise Exception('RandomContext can be active only once') 49 | 50 | # Save current state of RNG 51 | self.outside_state = RandomState() 52 | # Restore saved state of RNG for this context 53 | self.inside_state.restore() 54 | self._active = True 55 | 56 | def __exit__(self, exception_type, exception_value, traceback): 57 | # Save current state of RNG 58 | self.inside_state = RandomState() 59 | # Restore state of RNG saved in __enter__ 60 | self.outside_state.restore() 61 | self.outside_state = None 62 | 63 | self._active = False 64 | 65 | 66 | @attr.s 67 | class RandomizationConfig: 68 | # Seed for RNG used in shuffling the training data. 69 | data_seed = attr.ib(default=None) 70 | # Seed for RNG used in initializing the model. 71 | init_seed = attr.ib(default=None) 72 | # Seed for RNG used in computing the model's training loss. 73 | # Only relevant with internal randomness in the model, e.g. with dropout. 74 | model_seed = attr.ib(default=None) 75 | 76 | 77 | class Reproducible(object): 78 | def __init__(self, config: Dict) -> None: 79 | self.data_random = RandomContext( 80 | config["data_seed"]) 81 | self.model_random = RandomContext( 82 | config["model_seed"]) 83 | self.init_random = RandomContext( 84 | config["init_seed"]) 85 | -------------------------------------------------------------------------------- /src/utilities/registry.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The registry class makes it easy and quick to experiment with 3 | different algorithms, model architectures and hyperparameters. 
4 | We only need to decorate the class definitions with registry.load 5 | and create a yaml configuration file of all the arguments to pass. 6 | Later, if we want to change any parameter (eg. number of hidden layers, 7 | learning rate, or number of clients per round), we need not change the 8 | code but only change the parameters in yaml configuration file. 9 | for detailed explaination on the use of registry, see: 10 | github.com/NimbleEdge/EnvisEdge/blob/main/docs/Tutorial-Part-2-starting_with_nimbleedge.md 11 | ''' 12 | 13 | import collections 14 | import collections.abc 15 | import inspect 16 | import sys 17 | 18 | # a defaultdict provides default values for non-existent keys. 19 | LOOKUP_DICT = collections.defaultdict(dict) 20 | 21 | 22 | def load(kind, name): 23 | ''' 24 | A decorator to record callable object definitions 25 | for models,trainers,workers etc. 26 | Arguments 27 | ---------- 28 | kind: str 29 | Key to store in dictionary, used to specify the 30 | kind of object (eg. model, trainer). 31 | name: str 32 | Sub-key under kind key, used to specify name of 33 | of the object definition. 34 | Returns 35 | ---------- 36 | callable: 37 | Decorator function to store object definition. 38 | Examples 39 | ---------- 40 | >>> @registry.load('model', 'dlrm') 41 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 42 | ... def __init__(self, arg): 43 | ... self.arg = arg 44 | ''' 45 | 46 | assert kind != "class_map", "reserved keyword for kind \"class_map\"" 47 | registry = LOOKUP_DICT[kind] 48 | class_ref = LOOKUP_DICT["class_map"] 49 | 50 | def decorator(obj): 51 | if name in registry: 52 | raise LookupError('{} already present'.format(name, kind)) 53 | registry[name] = obj 54 | class_ref[obj.__module__ + "." + obj.__name__] = obj 55 | return obj 56 | return decorator 57 | 58 | 59 | def lookup(kind, name): 60 | ''' 61 | Returns the callable object definition stored in registry. 62 | Arguments 63 | ---------- 64 | kind: str 65 | Key to search in dictionary of registry. 66 | name: str 67 | Sub-key to search under kind key in dictionary 68 | of registry. 69 | Returns 70 | ---------- 71 | callable: 72 | Object definition stored in registry under key kind 73 | and sub-key name. 74 | Examples 75 | ---------- 76 | >>> @registry.load('model', 'dlrm') 77 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 78 | ... def __init__(self, arg): 79 | ... self.arg = arg 80 | >>> model = lookup('model', 'dlrm') # loads model class from registry 81 | >>> model # model is a DLRM_Net object 82 | __main__.DLRM_Net 83 | ''' 84 | 85 | # check if 'name' argument is a dictionary. 86 | # if yes, load the value under key 'name'. 87 | if isinstance(name, collections.abc.Mapping): 88 | name = name['name'] 89 | 90 | if kind not in LOOKUP_DICT: 91 | raise KeyError('Nothing registered under "{}"'.format(kind)) 92 | return LOOKUP_DICT[kind][name] 93 | 94 | 95 | def construct(kind, config, unused_keys=(), **kwargs): 96 | ''' 97 | Returns an object instance by loading definition from registry, 98 | and arguments from configuration file. 99 | Arguments 100 | ---------- 101 | kind: str 102 | Key to search in dictionary of registry. 103 | config: dict 104 | Configuration dictionary loaded from yaml file 105 | unused_keys: tuple 106 | Keys for values that are not passed as arguments to 107 | insantiate the object but are still present in config. 108 | **kwargs: dict, optional 109 | Extra arguments to pass. 
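            Keys passed here take precedence over duplicate keys in `config`.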
110 | Returns 111 | ---------- 112 | object: 113 | Constructed object using the parameters passed in config and \**kwargs. 114 | Examples 115 | ---------- 116 | >>> @registry.load('model', 'dlrm') 117 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 118 | ... def __init__(self, arg): 119 | ... self.arg = arg 120 | >>> model = construct('model', 'drlm', (), arg = 5) 121 | >>> model.arg # model is a DLRM_Net object with arg = 5 122 | 5 123 | ''' 124 | 125 | # check if 'config' argument is a string, 126 | # if yes, make it a dictionary. 127 | if isinstance(config, str): 128 | config = {'name': config} 129 | return instantiate( 130 | lookup(kind, config), 131 | config, 132 | unused_keys + ('name',), 133 | **kwargs) 134 | 135 | 136 | def instantiate(callable, config, unused_keys=(), **kwargs): 137 | ''' 138 | Instantiates an object after verifying the parameters. 139 | Arguments 140 | ---------- 141 | callable: callable 142 | Definition of object to be instantiated. 143 | config: dict 144 | Arguments to construct the object. 145 | unused_keys: tuple 146 | Keys for values that are not passed as arguments to 147 | insantiate the object but are still present in config. 148 | **kwargs: dict, optional 149 | Extra arguments to pass. 150 | Returns 151 | ---------- 152 | object: 153 | Instantiated object by the parameters passed in config and \**kwargs. 154 | Examples 155 | ---------- 156 | >>> @registry.load('model', 'dlrm') 157 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 158 | ... def __init__(self, arg): 159 | ... self.arg = arg 160 | >>> config = {'name': 'dlrm', 'arg': 5} # loaded from a yaml config file 161 | >>> call = lookup('model', 'dlrm') # Loads the class definition 162 | >>> model = instantiate(call, config, ('name')) 163 | >>> model.arg # model is a DRLM_Net object with arg = 5 164 | 5 165 | ''' 166 | 167 | # merge config arguments and kwargs in a single dictionary. 168 | merged = {**config, **kwargs} 169 | 170 | # check if callable has valid parameters. 171 | signature = inspect.signature(callable) 172 | for name, param in signature.parameters.items(): 173 | if param.kind in (inspect.Parameter.POSITIONAL_ONLY, 174 | inspect.Parameter.VAR_POSITIONAL): 175 | raise ValueError('Unsupported kind for param {}: {}'.format( 176 | name, param.kind)) 177 | 178 | if any(param.kind == inspect.Parameter.VAR_KEYWORD 179 | for param in signature.parameters.values()): 180 | return callable(**merged) 181 | 182 | # check and warn if config has unneccassary arguments that 183 | # callable does not require and are not mentioned in unused_keys. 
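    # (Despite its name, `missing` collects these superfluous keys so they can be reported and
    # dropped before the callable is invoked.)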
184 | missing = {} 185 | for key in list(merged.keys()): 186 | if key not in signature.parameters: 187 | if key not in unused_keys: 188 | missing[key] = merged[key] 189 | merged.pop(key) 190 | if missing: 191 | print('WARNING {}: superfluous {}'.format( 192 | callable, missing), file=sys.stderr) 193 | return callable(**merged) 194 | 195 | -------------------------------------------------------------------------------- /src/utilities/saver.py: -------------------------------------------------------------------------------- 1 | """Tools to save/restore model from checkpoints.""" 2 | 3 | import shutil 4 | import os 5 | import re 6 | 7 | import torch 8 | 9 | CHECKPOINT_PATTERN = re.compile('^model_checkpoint-(\d+)$') 10 | 11 | 12 | class ArgsDict(dict): 13 | 14 | def __init__(self, **kwargs): 15 | super(ArgsDict, self).__init__() 16 | for key, value in kwargs.items(): 17 | self[key] = value 18 | self.__dict__ = self 19 | 20 | 21 | def create_link(original, link_name): 22 | if os.path.islink(link_name): 23 | os.unlink(link_name) 24 | try: 25 | os.symlink(os.path.basename(original), link_name) 26 | except OSError: 27 | shutil.copy2(original, link_name) 28 | 29 | 30 | def load_checkpoint(model, 31 | optimizer, 32 | model_dir, 33 | map_location=None, 34 | step=None): 35 | path = os.path.join(model_dir, 'model_checkpoint') 36 | if step is not None: 37 | path += '-{:08d}'.format(step) 38 | if os.path.exists(path): 39 | print("Loading model from %s" % path) 40 | checkpoint = torch.load(path, map_location=map_location) 41 | model.load_state_dict(checkpoint['model'], strict=False) 42 | optimizer.load_state_dict(checkpoint['optimizer']) 43 | return checkpoint.get('step', 0), checkpoint.get('epoch', 0) 44 | return 0, 0 45 | 46 | 47 | def load_and_map_checkpoint(model, model_dir, remap): 48 | path = os.path.join(model_dir, 'model_checkpoint') 49 | print("Loading parameters %s from %s" % (remap.keys(), model_dir)) 50 | checkpoint = torch.load(path) 51 | new_state_dict = model.state_dict() 52 | for name, value in remap.items(): 53 | # TODO: smarter mapping. 54 | new_state_dict[name] = checkpoint['model'][value] 55 | model.load_state_dict(new_state_dict) 56 | 57 | 58 | def save_checkpoint(model, 59 | optimizer, 60 | step, 61 | epoch, 62 | model_dir, 63 | is_best, 64 | ignore=[], 65 | keep_every_n=10000000): 66 | if not os.path.exists(model_dir): 67 | os.makedirs(model_dir) 68 | path_without_step = os.path.join(model_dir, 'model_checkpoint') 69 | step_padded = format(step, '08d') 70 | state_dict = model.state_dict() 71 | if ignore: 72 | for key in state_dict.keys(): 73 | for item in ignore: 74 | if key.startswith(item): 75 | state_dict.pop(key) 76 | path_with_step = '{}-{}'.format(path_without_step, step_padded) 77 | torch.save({ 78 | 'model': state_dict, 79 | 'optimizer': optimizer.state_dict(), 80 | 'epoch': epoch, 81 | 'step': step 82 | }, path_with_step) 83 | create_link(path_with_step, path_without_step) 84 | create_link(path_with_step, os.path.join(model_dir, 'best_checkpoint')) 85 | 86 | # Cull old checkpoints. 
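    # Retention policy: walk the remaining checkpoints in step order and keep one at least every
    # `keep_every_n` steps; anything closer to the last kept checkpoint is unlinked. The checkpoint
    # that was just written is skipped by the name check below.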
87 | if keep_every_n is not None: 88 | all_checkpoints = [] 89 | for name in os.listdir(model_dir): 90 | m = CHECKPOINT_PATTERN.match(name) 91 | if m is None or name == os.path.basename(path_with_step): 92 | continue 93 | checkpoint_step = int(m.group(1)) 94 | all_checkpoints.append((checkpoint_step, name)) 95 | all_checkpoints.sort() 96 | 97 | last_step = float('-inf') 98 | for checkpoint_step, name in all_checkpoints: 99 | if checkpoint_step - last_step >= keep_every_n: 100 | last_step = checkpoint_step 101 | continue 102 | os.unlink(os.path.join(model_dir, name)) 103 | 104 | 105 | class Saver(object): 106 | """Class to manage save and restore for the model and optimizer.""" 107 | 108 | def __init__(self, model, optimizer, keep_every_n=None): 109 | self._model = model 110 | self._optimizer = optimizer 111 | self._keep_every_n = keep_every_n 112 | 113 | def restore(self, model_dir=None, map_location=None, step=None): 114 | """Restores model and optimizer from given directory. 115 | Returns 116 | Last training step for the model restored. 117 | """ 118 | if model_dir is None: 119 | return 0, 0 120 | last_step, epoch = load_checkpoint( 121 | self._model, self._optimizer, model_dir, map_location, step) 122 | return last_step, epoch 123 | 124 | def save(self, model_dir, step, epoch, is_best=False): 125 | """Saves model and optimizer to given directory. 126 | Args: 127 | model_dir: Model directory to save. If None ignore. 128 | step: Current training step. 129 | """ 130 | if model_dir is None: 131 | return 132 | save_checkpoint(self._model, self._optimizer, step, epoch, model_dir, 133 | keep_every_n=self._keep_every_n, is_best=is_best) 134 | 135 | def restore_part(self, other_model_dir, remap): 136 | """Restores part of the model from other directory. 137 | Useful to initialize part of the model with another pretrained model. 138 | Args: 139 | other_model_dir: Model directory to load from. 140 | remap: dict, remapping current parameters to the other model's. 
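                For example (hypothetical parameter names), remap={'predictor.proj.weight': 'mlp.down_proj.weight'}
                copies the other checkpoint's 'mlp.down_proj.weight' into this model's 'predictor.proj.weight'.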
141 | """ 142 | load_and_map_checkpoint(self._model, other_model_dir, remap) -------------------------------------------------------------------------------- /src/utilities/sys_utils.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import psutil 3 | import argparse 4 | import torch 5 | from typing import Dict 6 | from .cuda_utils import get_gpu_info 7 | 8 | def get_system_info() -> Dict: 9 | """Get system information including CPU and RAM details.""" 10 | cpu_info = { 11 | 'processor': platform.processor(), 12 | 'physical_cores': psutil.cpu_count(logical=False), 13 | 'total_cores': psutil.cpu_count(logical=True), 14 | 'max_frequency': f"{psutil.cpu_freq().max:.0f}MHz" if psutil.cpu_freq() else "Unknown", 15 | 'current_frequency': f"{psutil.cpu_freq().current:.0f}MHz" if psutil.cpu_freq() else "Unknown" 16 | } 17 | 18 | memory = psutil.virtual_memory() 19 | ram_info = { 20 | 'total': f"{memory.total / (1024**3):.2f}GB", 21 | 'available': f"{memory.available / (1024**3):.2f}GB", 22 | 'used_percent': f"{memory.percent}%" 23 | } 24 | 25 | return { 26 | 'system': platform.system(), 27 | 'release': platform.release(), 28 | 'version': platform.version(), 29 | 'machine': platform.machine(), 30 | 'cpu': cpu_info, 31 | 'ram': ram_info 32 | } 33 | 34 | 35 | def print_system_info(args: argparse.Namespace) -> None: 36 | """Print system configuration information.""" 37 | print("\nSystem Configuration:") 38 | print("-" * 50) 39 | system_info = get_system_info() 40 | print(f"OS: {system_info['system']} {system_info['release']}") 41 | print(f"CPU: {system_info['cpu']['processor']}") 42 | print(f"Physical cores: {system_info['cpu']['physical_cores']}") 43 | print(f"Total cores: {system_info['cpu']['total_cores']}") 44 | print(f"Max CPU frequency: {system_info['cpu']['max_frequency']}") 45 | print(f"Current CPU frequency: {system_info['cpu']['current_frequency']}") 46 | print(f"RAM: Total={system_info['ram']['total']}, Available={system_info['ram']['available']} ({system_info['ram']['used_percent']} used)") 47 | 48 | if args.device == 'cuda': 49 | print("\nGPU Configuration:") 50 | print("-" * 50) 51 | gpu_info = get_gpu_info() 52 | for i, gpu in enumerate(gpu_info or []): 53 | print(f"\nGPU {i}: {gpu['name']}") 54 | print(f"Compute capability: {gpu['compute_capability']}") 55 | print(f"Total memory: {gpu['total_memory']}") 56 | print(f"Free memory: {gpu['free_memory']}") 57 | print(f"Multi processors: {gpu['multi_processor_count']}") 58 | 59 | print("\nPyTorch version:", torch.__version__) 60 | print("CUDA version:", torch.version.cuda if torch.cuda.is_available() else "N/A") 61 | print("-" * 50) 62 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Training script for sparsity predictors using datasets generated by generate_dataset.py. 4 | 5 | This script trains predictors to identify which MLP neurons will be most activated 6 | based on the hidden states before each MLP layer. Uses the last token representations 7 | from the generated datasets. 
8 | 9 | Usage: 10 | # Start fresh training for single layer 11 | python train.py \ 12 | --config meta-llama/Llama-2-7b-hf \ 13 | --dataset_dir ./data/c4 \ 14 | --output_dir ./trained_predictors \ 15 | --layer_indices 0 \ 16 | --batch_size 32 \ 17 | --num_epochs 10 \ 18 | --learning_rate 1e-5 19 | 20 | # Start fresh training for multiple layers 21 | python train.py \ 22 | --config meta-llama/Llama-2-7b-hf \ 23 | --dataset_dir ./data/c4 \ 24 | --output_dir ./trained_predictors \ 25 | --layer_indices 0 1 2 3 \ 26 | --batch_size 32 \ 27 | --num_epochs 10 \ 28 | --learning_rate 1e-5 29 | 30 | # Train predictors for ALL layers automatically 31 | python train.py \ 32 | --config meta-llama/Llama-2-7b-hf \ 33 | --dataset_dir ./data/c4 \ 34 | --output_dir ./trained_predictors \ 35 | --layer_indices all \ 36 | --batch_size 32 \ 37 | --num_epochs 10 \ 38 | --learning_rate 1e-5 39 | 40 | # Train predictors with hyperparameter grid (different LoRA sizes) 41 | python train.py \ 42 | --config meta-llama/Llama-2-7b-hf \ 43 | --dataset_dir ./data/c4 \ 44 | --output_dir ./trained_predictors \ 45 | --layer_indices 0 1 2 \ 46 | --lora_sizes 4.0 10.0 20.0 30.0 \ 47 | --batch_size 32 \ 48 | --num_epochs 10 \ 49 | --learning_rate 1e-5 50 | 51 | # Resume from latest checkpoint 52 | python train.py \ 53 | --config meta-llama/Llama-2-7b-hf \ 54 | --dataset_dir ./data/c4 \ 55 | --output_dir ./trained_predictors \ 56 | --layer_indices 0 1 2 \ 57 | --batch_size 32 \ 58 | --num_epochs 10 \ 59 | --learning_rate 1e-5 \ 60 | --resume_from_checkpoint 61 | 62 | # Resume from specific checkpoint 63 | python train.py \ 64 | --config meta-llama/Llama-2-7b-hf \ 65 | --dataset_dir ./data/c4 \ 66 | --output_dir ./trained_predictors \ 67 | --layer_indices 0 \ 68 | --batch_size 32 \ 69 | --num_epochs 10 \ 70 | --learning_rate 1e-5 \ 71 | --resume_from_checkpoint \ 72 | --checkpoint_path ./trained_predictors/checkpoint_layer_0_step_5000.pt 73 | 74 | # Resume only from best-performing lora size per layer 75 | python train.py \ 76 | --config meta-llama/Llama-2-7b-hf \ 77 | --dataset_dir ./data/c4 \ 78 | --output_dir ./trained_predictors \ 79 | --layer_indices 0 \ 80 | --batch_size 32 \ 81 | --num_epochs 10 \ 82 | --learning_rate 1e-5 \ 83 | --resume_from_checkpoint \ 84 | --load_best_only 85 | """ 86 | 87 | import argparse 88 | import logging 89 | import os 90 | import time 91 | 92 | import torch 93 | 94 | from transformers import AutoConfig 95 | import wandb 96 | 97 | from transformers.trainer_utils import set_seed 98 | 99 | from src.trainer import MultiLayerPredictorTrainer 100 | 101 | # Setup logging 102 | logging.basicConfig(level=logging.INFO) 103 | logger = logging.getLogger(__name__) 104 | 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser(description="Train sparsity predictors from pre-generated datasets") 108 | parser.add_argument("--config", type=str, required=True, help="Path to model config file") 109 | parser.add_argument("--dataset_dir", type=str, required=True, help="Directory containing dataset.csv and arrays/") 110 | parser.add_argument("--output_dir", type=str, required=True, help="Output directory for trained models") 111 | parser.add_argument("--layer_indices", nargs='+', required=True, help="Which layers to train predictors for (can specify multiple layer numbers or 'all' for all layers)") 112 | parser.add_argument("--batch_size", type=int, default=64, help="Training batch size") 113 | parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs") 114 | 
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate") 115 | parser.add_argument("--lora_size", type=int, default=None, help="LoRA bottleneck size (default: 4% of intermediate_size)") 116 | parser.add_argument("--lora_sizes", nargs='+', type=float, default=None, help="LoRA sizes as percentages (e.g., 4.0 10.0 20.0 30.0)") 117 | parser.add_argument("--val_split", type=float, default=0.1, help="Validation split fraction") 118 | parser.add_argument("--cache_size", type=int, default=50, help="Number of .npz chunk files to cache in memory") 119 | parser.add_argument("--load_full_dataset", action="store_true", help="Load full dataset into memory at initialization (faster but uses more memory)") 120 | parser.add_argument("--checkpoint_save_interval", type=int, default=20000, help="Save checkpoint every N steps") 121 | parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume training from the latest checkpoint") 122 | parser.add_argument("--checkpoint_path", type=str, default=None, help="Specific checkpoint path to resume from (optional)") 123 | parser.add_argument("--load_best_only", action="store_true", help="Resume training only from best performing lora size for each layer.") 124 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 125 | parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases logging") 126 | parser.add_argument("--wandb_project", type=str, default="llama-skip-predictors", help="W&B project name") 127 | parser.add_argument("--wandb_entity", type=str, default="llama-skip-predictors", help="W&B entity name") 128 | parser.add_argument("--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)") 129 | 130 | 131 | args = parser.parse_args() 132 | 133 | # Set seed 134 | set_seed(args.seed) 135 | 136 | # Setup device 137 | if args.device == "auto": 138 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 139 | else: 140 | device = torch.device(args.device) 141 | 142 | logger.info(f"Using device: {device}") 143 | 144 | # Load config 145 | config = AutoConfig.from_pretrained(args.config) 146 | 147 | # Handle 'all' option for layer_indices 148 | if len(args.layer_indices) == 1 and args.layer_indices[0] == "all": 149 | # Get number of layers from config 150 | num_layers = getattr(config, 'num_hidden_layers', None) or getattr(config, 'n_layer', None) or getattr(config, 'num_layers', None) 151 | if num_layers is None: 152 | raise ValueError("Could not determine number of layers from config. 
Please specify layer indices explicitly.") 153 | args.layer_indices = list(range(num_layers)) 154 | logger.info(f"Training predictors for all {num_layers} layers: {args.layer_indices}") 155 | else: 156 | # Convert to integers 157 | try: 158 | args.layer_indices = [int(idx) for idx in args.layer_indices] 159 | except ValueError: 160 | raise ValueError("Layer indices must be integers or 'all'") 161 | logger.info(f"Training predictors for specified layers: {args.layer_indices}") 162 | 163 | logger.info(f"Model config: hidden_size={config.hidden_size}, intermediate_size={config.intermediate_size}") 164 | 165 | # Initialize wandb for multi-layer training 166 | if args.use_wandb: 167 | wandb.init( 168 | entity=args.wandb_entity, 169 | project=args.wandb_project, 170 | config=vars(args), 171 | name=f"predictor-layers-{'-'.join(map(str, args.layer_indices))}-training-{int(time.time())}" 172 | ) 173 | 174 | # Initialize multi-layer trainer 175 | trainer = MultiLayerPredictorTrainer( 176 | config=config, 177 | layer_indices=args.layer_indices, 178 | device=device, 179 | lora_size=args.lora_size, 180 | lora_sizes=args.lora_sizes 181 | ) 182 | 183 | # Train all layers 184 | trainer.train_all_layers( 185 | dataset_dir=args.dataset_dir, 186 | num_epochs=args.num_epochs, 187 | batch_size=args.batch_size, 188 | learning_rate=args.learning_rate, 189 | val_split=args.val_split, 190 | cache_size=args.cache_size, 191 | load_full_dataset=args.load_full_dataset, 192 | use_wandb=args.use_wandb, 193 | save_dir=args.output_dir, 194 | save_interval=args.checkpoint_save_interval, 195 | resume_from_checkpoint=args.resume_from_checkpoint, 196 | load_best_only= args.load_best_only, 197 | checkpoint_path=args.checkpoint_path, 198 | seed=args.seed 199 | ) 200 | 201 | if args.use_wandb: 202 | wandb.finish() 203 | 204 | logger.info(f"Training completed for all layers: {args.layer_indices}") 205 | 206 | 207 | if __name__ == "__main__": 208 | main() -------------------------------------------------------------------------------- /train_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Parallel training script for sparsity predictors 4 | # Divides layer training across multiple parallel train.py jobs 5 | # 6 | # Usage: 7 | # ./train_parallel.sh --layers-per-job 4 --config meta-llama/Llama-2-7b-hf \ 8 | # --dataset_dir ./data/c4 --output_dir ./trained_predictors \ 9 | # --layer_indices 0 1 2 3 4 5 6 7 8 9 10 11 \ 10 | # --batch_size 32 --num_epochs 10 --learning_rate 1e-5 11 | # 12 | # Example with LoRA grid: 13 | # ./train_parallel.sh --layers-per-job 3 --num_layers 32 --config meta-llama/Llama-2-7b-hf \ 14 | # --dataset_dir ./data/c4 --output_dir ./trained_predictors \ 15 | # --layer_indices all --lora_sizes 4.0 10.0 20.0 30.0 \ 16 | # --batch_size 32 --num_epochs 10 --learning_rate 1e-5 17 | 18 | set -e # Exit on any error 19 | 20 | # Default values 21 | LAYERS_PER_JOB=4 22 | TRAIN_ARGS=() 23 | LAYER_INDICES=() 24 | CONFIG="" 25 | NUM_LAYERS="" 26 | PYTHON_CMD="python" 27 | 28 | # Colors for output 29 | RED='\033[0;31m' 30 | GREEN='\033[0;32m' 31 | YELLOW='\033[1;33m' 32 | BLUE='\033[0;34m' 33 | NC='\033[0m' # No Color 34 | 35 | # Function to print colored output 36 | print_info() { 37 | echo -e "${BLUE}[INFO]${NC} $1" 38 | } 39 | 40 | print_success() { 41 | echo -e "${GREEN}[SUCCESS]${NC} $1" 42 | } 43 | 44 | print_warning() { 45 | echo -e "${YELLOW}[WARNING]${NC} $1" 46 | } 47 | 48 | print_error() { 49 | echo -e "${RED}[ERROR]${NC} $1" 50 | } 51 | 52 | 
# Function to show usage 53 | show_usage() { 54 | cat << EOF 55 | Usage: $0 --layers-per-job N [train.py arguments...] 56 | 57 | Parallel training script that divides layer training across multiple train.py jobs. 58 | 59 | Required arguments: 60 | --layers-per-job N Number of layers to train in each parallel job 61 | 62 | Script-specific arguments: 63 | --num_layers N Total number of layers (required when --layer_indices is 'all') 64 | 65 | All other arguments are passed directly to train.py. Key arguments include: 66 | --config MODEL_PATH Path to model config (required for train.py) 67 | --dataset_dir PATH Path to dataset directory (required for train.py) 68 | --output_dir PATH Output directory for trained models (required for train.py) 69 | --layer_indices LAYERS Layer indices to train (space-separated numbers or 'all') 70 | --lora_sizes SIZES LoRA sizes as percentages (e.g., 4.0 10.0 20.0 30.0) 71 | --batch_size N Training batch size 72 | --num_epochs N Number of training epochs 73 | --learning_rate RATE Learning rate 74 | --use_wandb Enable Weights & Biases logging 75 | 76 | Examples: 77 | # Train 12 layers with 4 layers per job (3 parallel jobs) 78 | $0 --layers-per-job 4 --config meta-llama/Llama-2-7b-hf \\ 79 | --dataset_dir ./data/c4 --output_dir ./trained_predictors \\ 80 | --layer_indices 0 1 2 3 4 5 6 7 8 9 10 11 \\ 81 | --batch_size 32 --num_epochs 10 --learning_rate 1e-5 82 | 83 | # Train all 32 layers with LoRA grid using 3 layers per job 84 | $0 --layers-per-job 3 --num_layers 32 --config meta-llama/Llama-2-7b-hf \\ 85 | --dataset_dir ./data/c4 --output_dir ./trained_predictors \\ 86 | --layer_indices all --lora_sizes 4.0 10.0 20.0 30.0 \\ 87 | --batch_size 32 --num_epochs 10 --learning_rate 1e-5 --use_wandb 88 | 89 | EOF 90 | } 91 | 92 | # Parse arguments 93 | while [[ $# -gt 0 ]]; do 94 | case $1 in 95 | --layers-per-job) 96 | LAYERS_PER_JOB="$2" 97 | shift 2 98 | ;; 99 | --num_layers) 100 | NUM_LAYERS="$2" 101 | shift 2 102 | ;; 103 | --layer_indices) 104 | shift 105 | # Collect all layer indices until next flag or end 106 | while [[ $# -gt 0 && ! "$1" =~ ^-- ]]; do 107 | LAYER_INDICES+=("$1") 108 | shift 109 | done 110 | # Add to train args 111 | TRAIN_ARGS+=(--layer_indices "${LAYER_INDICES[@]}") 112 | ;; 113 | --config) 114 | CONFIG="$2" 115 | TRAIN_ARGS+=("$1" "$2") 116 | shift 2 117 | ;; 118 | --help|-h) 119 | show_usage 120 | exit 0 121 | ;; 122 | *) 123 | # Pass through all other arguments to train.py 124 | TRAIN_ARGS+=("$1") 125 | shift 126 | ;; 127 | esac 128 | done 129 | 130 | # Validate required arguments 131 | if [[ -z "$LAYERS_PER_JOB" ]] || [[ ! "$LAYERS_PER_JOB" =~ ^[0-9]+$ ]] || [[ "$LAYERS_PER_JOB" -lt 1 ]]; then 132 | print_error "Invalid --layers-per-job value: '$LAYERS_PER_JOB'. Must be a positive integer." 133 | show_usage 134 | exit 1 135 | fi 136 | 137 | if [[ -z "$CONFIG" ]]; then 138 | print_error "Missing required argument: --config" 139 | show_usage 140 | exit 1 141 | fi 142 | 143 | if [[ ${#LAYER_INDICES[@]} -eq 0 ]]; then 144 | print_error "Missing required argument: --layer_indices" 145 | show_usage 146 | exit 1 147 | fi 148 | 149 | print_info "Parallel training configuration:" 150 | print_info " Layers per job: $LAYERS_PER_JOB" 151 | print_info " Model config: $CONFIG" 152 | print_info " Layer indices: ${LAYER_INDICES[*]}" 153 | 154 | # Handle 'all' layers by generating the list from num_layers 155 | if [[ ${#LAYER_INDICES[@]} -eq 1 && "${LAYER_INDICES[0]}" == "all" ]]; then 156 | if [[ -z "$NUM_LAYERS" ]] || [[ ! 
"$NUM_LAYERS" =~ ^[0-9]+$ ]] || [[ "$NUM_LAYERS" -lt 1 ]]; then 157 | print_error "When --layer_indices is 'all', you must specify --num_layers with a positive integer" 158 | print_error "Example: --num_layers 32 --layer_indices all" 159 | exit 1 160 | fi 161 | 162 | print_info "Generating layer indices for 'all' option with $NUM_LAYERS layers..." 163 | 164 | # Generate layer indices from 0 to NUM_LAYERS-1 165 | LAYER_INDICES=() 166 | for ((i=0; i "$log_file" 2>&1; then 227 | print_success "$job_name completed successfully" 228 | return 0 229 | else 230 | print_error "$job_name failed! Check log: $log_file" 231 | return 1 232 | fi 233 | } 234 | 235 | # Start all jobs in parallel 236 | print_info "Starting $NUM_JOBS parallel training jobs..." 237 | JOB_PIDS=() 238 | FAILED_JOBS=() 239 | 240 | for ((job_id=0; job_id