├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── pull_request_template.md ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── llama1b │ ├── benchmark_cpu_l1.log │ ├── benchmark_gpu.log │ └── summary.json └── llama3b │ ├── benchmark_cpu_l3.log │ └── summary.json ├── configs ├── base.yml ├── llama_skip_causal_1b.json ├── llama_skip_causal_1b_predictor_training.json ├── llama_skip_causal_3b.json └── llama_skip_causal_3b_predictor_training.json ├── measure_contextual_sparsity.py ├── pyproject.toml ├── requirements.txt ├── run_benchmark.py ├── setup.py ├── sparse_transformers ├── CMakeLists.txt ├── __init__.py └── csrc │ ├── approx_topk.h │ ├── sparse_mlp_cuda.cu │ ├── sparse_mlp_op.cpp │ └── weight_cache.h ├── src ├── __init__.py ├── models │ ├── __init__.py │ └── llama │ │ ├── __init__.py │ │ ├── configuration_llama_skip.py │ │ └── modelling_llama_skip.py ├── trainer.py └── utilities │ ├── __init__.py │ ├── cuda_utils.py │ ├── logger.py │ ├── random.py │ ├── registry.py │ └── saver.py └── train_predictors.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Report a bug to help us improve sparse transformers 4 | title: '[BUG] ' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | **Describe the Bug** 10 | Include a description of what the bug is and the steps to reproduce the behavior. 11 | 12 | **Expected Behavior** 13 | A clear and concise description of what you expected to happen. 14 | 15 | **Additional Information** 16 | Add any other context about the problem here. 17 | 18 | **Possible Solution** 19 | If you have suggestions on how to fix the issue, please describe them here. 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea to improve sparse transformers 4 | title: '[FEATURE] ' 5 | labels: enhancement 6 | assignees: '' 7 | --- 8 | 9 | **Describe the feature request** 10 | 11 | **Describe the solution you'd like** 12 | 13 | **Describe alternatives you've considered** 14 | 15 | **Additional context** 16 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Please provide a description of this PR. 3 | 4 | Fixes # (issue) 5 | 6 | ## Checklist: 7 | - [ ] I have added tests that prove my fix is effective or that my feature works 8 | - [ ] Has user-facing changes.
This may include API or behavior changes and performance improvements, etc. 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | **/__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | llama.cpp/ 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | .cursorrules 24 | wandb/ 25 | # CUDA 26 | *.i 27 | *.ii 28 | *.gpu 29 | *.ptx 30 | *.cubin 31 | *.fatbin 32 | *.o 33 | *.obj 34 | *.pkl 35 | *.png 36 | # IDE specific files 37 | .idea/ 38 | .vscode/ 39 | *.swp 40 | *.swo 41 | .project 42 | .pydevproject 43 | .settings/ 44 | .vs/ 45 | 46 | # Environment 47 | .env 48 | .venv 49 | env/ 50 | venv/ 51 | ENV/ 52 | env.bak/ 53 | venv.bak/ 54 | 55 | # Distribution / packaging 56 | .Python 57 | build/ 58 | develop-eggs/ 59 | dist/ 60 | downloads/ 61 | eggs/ 62 | .eggs/ 63 | lib/ 64 | lib64/ 65 | parts/ 66 | sdist/ 67 | var/ 68 | wheels/ 69 | *.egg-info/ 70 | .installed.cfg 71 | *.egg 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # Testing 77 | htmlcov/ 78 | .tox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | .hypothesis/ 86 | 87 | # Logs and databases 88 | *.sqlite 89 | *.db 90 | 91 | # OS generated files 92 | .DS_Store 93 | .DS_Store? 94 | ._* 95 | .Spotlight-V100 96 | .Trashes 97 | ehthumbs.db 98 | Thumbs.db 99 | build.log 100 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community.
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official email address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at . 63 | All complaints will be reviewed and investigated promptly and fairly. 64 | 65 | All community leaders are obligated to respect the privacy and security of the 66 | reporter of any incident. 67 | 68 | ## Enforcement Guidelines 69 | 70 | Community leaders will follow these Community Impact Guidelines in determining 71 | the consequences for any action they deem in violation of this Code of Conduct: 72 | 73 | ### 1. Correction 74 | 75 | **Community Impact**: Use of inappropriate language or other behavior deemed 76 | unprofessional or unwelcome in the community. 77 | 78 | **Consequence**: A private, written warning from community leaders, providing 79 | clarity around the nature of the violation and an explanation of why the 80 | behavior was inappropriate. A public apology may be requested. 81 | 82 | ### 2. Warning 83 | 84 | **Community Impact**: A violation through a single incident or series of 85 | actions. 86 | 87 | **Consequence**: A warning with consequences for continued behavior. No 88 | interaction with the people involved, including unsolicited interaction with 89 | those enforcing the Code of Conduct, for a specified period of time. This 90 | includes avoiding interactions in community spaces as well as external channels 91 | like social media. 
Violating these terms may lead to a temporary or permanent 92 | ban. 93 | 94 | ### 3. Temporary Ban 95 | 96 | **Community Impact**: A serious violation of community standards, including 97 | sustained inappropriate behavior. 98 | 99 | **Consequence**: A temporary ban from any sort of interaction or public 100 | communication with the community for a specified period of time. No public or 101 | private interaction with the people involved, including unsolicited interaction 102 | with those enforcing the Code of Conduct, is allowed during this period. 103 | Violating these terms may lead to a permanent ban. 104 | 105 | ### 4. Permanent Ban 106 | 107 | **Community Impact**: Demonstrating a pattern of violation of community 108 | standards, including sustained inappropriate behavior, harassment of an 109 | individual, or aggression toward or disparagement of classes of individuals. 110 | 111 | **Consequence**: A permanent ban from any sort of public interaction within the 112 | community. 113 | 114 | ## Attribution 115 | 116 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 117 | version 2.1, available at 118 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 119 | 120 | Community Impact Guidelines were inspired by 121 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 122 | 123 | For answers to common questions about this code of conduct, see the FAQ at 124 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 125 | [https://www.contributor-covenant.org/translations][translations]. 126 | 127 | [homepage]: https://www.contributor-covenant.org 128 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 129 | [Mozilla CoC]: https://github.com/mozilla/diversity 130 | [FAQ]: https://www.contributor-covenant.org/faq 131 | [translations]: https://www.contributor-covenant.org/translations 132 | 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Sparse Transformers 2 | 3 | Thank you for your interest in contributing to Sparse Transformers! This document provides guidelines and instructions for contributing to this project. 4 | 5 | ## Development Workflow 6 | 7 | We follow a fork and pull request workflow for all contributions. Here's how it works: 8 | 9 | 1. Fork the repository 10 | 2. Create a new branch for your feature/fix 11 | 3. Make your changes 12 | 4. Submit a pull request 13 | 14 | ### Detailed Steps 15 | 16 | 1. **Fork the Repository** 17 | - Click the "Fork" button on the top right of the repository page 18 | - Clone your fork locally: 19 | ```bash 20 | git clone https://github.com/YOUR-USERNAME/sparse_transformers.git 21 | cd sparse_transformers 22 | ``` 23 | 24 | 2. **Create a Branch** 25 | - Create a new branch for your changes: 26 | ```bash 27 | git checkout -b feature/your-feature-name 28 | ``` 29 | 30 | 3. **Make Changes** 31 | - Make your changes and commit them 32 | - Write clear commit messages 33 | - Test your changes thoroughly 34 | 35 | 4. **Submit a Pull Request** 36 | - Push your branch to your fork 37 | - Create a pull request against the main repository 38 | - Fill out the pull request template completely 39 | 40 | ## Developer Certificate of Origin (DCO) 41 | 42 | This project requires all contributors to sign off on their commits. This is done through the Developer Certificate of Origin (DCO).
The DCO is a lightweight way for contributors to certify that they wrote or otherwise have the right to submit the code they are contributing to the project. 43 | 44 | ### How to Sign Off 45 | 46 | Each commit message must include a Signed-off-by line with your name and email address. You can add this automatically using the `-s` flag when committing: 47 | 48 | ```bash 49 | git commit -s -m "Your commit message" 50 | ``` 51 | 52 | The sign-off line should look like this: 53 | ``` 54 | Signed-off-by: Your Name 55 | ``` 56 | 57 | For more information about the DCO, please visit [DCO App Documentation](https://github.com/dcoapp/app#how-it-works). 58 | 59 | ## Pull Request Requirements 60 | 61 | 1. **Code Quality** 62 | - Follow the existing code style 63 | - Write clear, maintainable code 64 | - Include appropriate tests 65 | - Update documentation as needed 66 | 67 | 2. **Commit Messages** 68 | - Use clear, descriptive commit messages 69 | - Include the DCO sign-off in each commit 70 | - Reference any related issues 71 | 72 | 3. **Pull Request Description** 73 | - Clearly describe the changes 74 | - Reference any related issues 75 | - Ensure all tests and checks are passing 76 | 77 | ## Getting Help 78 | 79 | If you need help or have questions: 80 | - Open an issue 81 | - Join our community discussions 82 | - Reach out to the maintainers 83 | 84 | Thank you for contributing to Sparse Transformers! 85 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fused Sparse C++ Kernels for Transformers 2 | 3 | ## Overview 4 | 5 | The project implements sparse multiplication and fuses up/down projections in the MLP layers through low rank weight activations. 6 | Work is based on [Deja Vu](https://arxiv.org/abs/2310.17157) and Apple's [LLM in a Flash](https://arxiv.org/abs/2312.11514). 
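A minimal PyTorch sketch of the idea: a low-rank predictor scores the intermediate neurons, a threshold turns the scores into a binary mask, and the MLP matmuls are restricted to the selected rows/columns. The names `lora_A`, `lora_B`, and the `tau` threshold are illustrative assumptions, not the extension's actual API; weights follow the `nn.Linear` convention (`gate_w`, `up_w`: `[intermediate, hidden]`, `down_w`: `[hidden, intermediate]`).

```python
import torch
import torch.nn.functional as F

def sparse_mlp_step(x, gate_w, up_w, down_w, lora_A, lora_B, tau=0.5):
    """One decode step of a predictor-gated MLP (illustrative sketch only)."""
    # Low-rank projection produces importance scores over the intermediate dimension
    scores = torch.sigmoid(x @ lora_A @ lora_B)   # [batch, intermediate]
    mask = (scores > tau).any(dim=0)              # binary mask, union across the batch
    idx = mask.nonzero(as_tuple=True)[0]          # active neuron indices

    # Fused gate+up projection over the active rows, then sparse down projection
    gate = F.silu(x @ gate_w[idx].T)              # SiLU(gate)
    up = x @ up_w[idx].T
    return (gate * up) @ down_w[:, idx].T         # only active columns of down_proj
```

Only the selected rows and columns participate in the matmuls; the fused C++ kernels described below apply the same idea against a cached, contiguous copy of the active weights.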
7 | 8 | ### Benefits 9 | - **1.6-1.8x overall gain in TTFT and TPS** (4-5x gain in MLP Inference) 10 | - **26.4%** reduction in memory usage 11 | - **6.7×** faster index selection and replacement for weight caching 12 | 13 | 14 | ``` 15 | ┌─────────────────────────────────────────────────────────────────┐ 16 | │ Sparse LLM Inference Pipeline │ 17 | ├─────────────────────────────────────────────────────────────────┤ 18 | │ Sparsity Selection │ 19 | │ ├─ Hidden States → LoRA Projection (Importance Scoring) │ 20 | │ ├─ Binary Mask Generation: (scores > threshold) │ 21 | │ └─ Mask Normalization: Union across batch dimension │ 22 | ├─────────────────────────────────────────────────────────────────┤ 23 | │ Differential Weight Caching │ 24 | │ ├─ Mask Change Detection: XOR with previous mask │ 25 | │ ├─ Paired Replacement: Direct substitution algorithm │ 26 | │ └─ Zero-Copy Tensor Views: torch::from_blob references │ 27 | ├─────────────────────────────────────────────────────────────────┤ 28 | │ Sparse Computation │ 29 | │ ├─ Concatenated Gate+Up Projection (Fused Operation) │ 30 | │ ├─ Element-wise Activation: σ(gate) ⊙ up │ 31 | │ └─ Sparse Down Projection: Only active intermediate dims │ 32 | └─────────────────────────────────────────────────────────────────┘ 33 | ``` 34 | 35 | **Keywords:** Large Language Models, Sparse Inference, Differential Weight Caching 36 | 37 | ## Performance Benchmarks 38 | State of Implementation: 39 | - [x] Torch CPU kernels for fp16, fp32 40 | - [x] Differential weight caching and selection for dynamic sparsity 41 | - [ ] CUDA kernels for Sparse Inferencing 42 | - [ ] CPU kernels for int8, int32, int64 43 | 44 | ### CPU Performance 45 | ``` 46 | Sparse LLaMA 3.2 3B vs LLaMA 3.2 3B (on HuggingFace Implementation): 47 | 48 | - Time to First Token (TTFT): 1.51× faster (1.209s → 0.803s) 49 | - Output Generation Speed: 1.79× faster (0.7 → 1.2 tokens/sec) 50 | - Total Throughput: 1.78× faster (0.7 → 1.3 tokens/sec) 51 | - Memory Usage: 26.4% reduction (13.25GB → 9.75GB) 52 | ``` 53 | 54 | ### GPU Performance 55 | 56 | ``` 57 | Sparse LLaMA 3.2 3B vs Standard LLaMA 3.2 3B CUDA Results (on HuggingFace Implementation): 58 | 59 | - Average time (Sparse): 0.021s 60 | - Average time (Standard): 0.018s 61 | - CUDA Speedups: 0.86x (WIP) 62 | ``` 63 | 64 | ## Usage 65 | 66 | ### Quick Benchmark 67 | 68 | ```bash 69 | # Run comprehensive benchmark 70 | 71 | python run_benchmark.py \ 72 | --device cpu \ # Device: 'cpu' or 'cuda' 73 | --config configs/llama_skip_causal_3b.json \ # Model configuration 74 | --num_runs 50 \ # Number of benchmark runs 75 | --verbose True # Detailed timing output 76 | 77 | # Expected output: 78 | # ⚡ TTFT Speedup: 1.51x 79 | # 🚀 Output TPS Speedup: 1.79x 80 | # 📊 Total Throughput Speedup: 1.78x 81 | ``` 82 | 83 | ## Implementation Details 84 | 85 | ### Paired Replacement with Differential Caching 86 | _sparse_transformers/csrc/weight_cache.h_ 87 | 88 | The weight cache is a class that manages the active weights for the sparse MLP. It differentially updates the MLP tensor memory pool for the next token based on the predicted sparsity mask. 
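Sketched as a Python mock-up (the pool/slot bookkeeping and names below are simplified assumptions for illustration, not the extension's actual memory layout):

```python
import torch

class PairedReplacementCache:
    """Fixed-size pool of active MLP rows, updated differentially per token (illustrative)."""

    def __init__(self, weight, mask):
        self.weight = weight                                 # full [intermediate, hidden] matrix
        self.slots = mask.nonzero(as_tuple=True)[0].clone()  # neuron id held by each pool slot
        self.pool = weight[self.slots].clone()               # contiguous copy of active rows
        self.prev_mask = mask.clone()

    def update(self, mask):
        changed = mask ^ self.prev_mask                      # mask change detection (XOR)
        removed = (changed & self.prev_mask).nonzero(as_tuple=True)[0]
        added = (changed & mask).nonzero(as_tuple=True)[0]
        # Paired replacement: write each newly active row into the slot of a newly
        # inactive neuron (a real implementation also handles |removed| != |added|).
        for old_id, new_id in zip(removed.tolist(), added.tolist()):
            slot = (self.slots == old_id).nonzero(as_tuple=True)[0]
            self.pool[slot] = self.weight[new_id]
            self.slots[slot] = new_id
        self.prev_mask = mask.clone()
```

Because only the entries that flipped are touched, the per-token update is a handful of row copies into an already-allocated pool rather than a full `index_select` rebuild. The C++ class exposes this through a single update call: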
89 | 90 | ```cpp 91 | class WeightCache { 92 | // Paired replacement algorithm for differential updates 93 | void update_active_weights(const torch::Tensor &mask) 94 | 95 | }; 96 | ``` 97 | 98 | **Performance Impact:** 99 | - **6.7× faster cache updates**: 29.89ms (naive `index_select`) → 4.46ms (paired replacement) 100 | - **Better cache locality**: Row major for Up Projection and Column major for Down Projection Matrices 101 | - **Contiguous Memory Access**: Single memcpy for cache updates 102 | 103 | ### Sparse MLP Inference 104 | _sparse_transformers/csrc/sparse_mlp_op.cpp_ 105 | 106 | ```python 107 | sparse_mlp_forward( 108 | x.detach(), 109 | self.weight_cache.get_concat_weight(), 110 | self.weight_cache.get_active_down_weight(), 111 | self.down_proj_buffer, 112 | self.combined_proj_buffer, 113 | "silu" 114 | ) 115 | ``` 116 | 117 | **Performance Impact:** 118 | - **5× faster CPU MLP inference**: 30.1ms → 6.02ms 119 | - OpenMP parallelization with `torch::at::parallel_for` 120 | - Bounded memory usage with weight cache memory pool 121 | 122 | ## Project Structure 123 | 124 | ``` 125 | ├── sparse_transformers/ # C++ extension module 126 | │ ├── csrc/ 127 | │ │ ├── sparse_mlp_op.cpp # Main CPU/CUDA dispatcher 128 | │ │ ├── sparse_mlp_cuda.cu # CUDA kernels 129 | │ │ └── weight_cache.h # Paired replacement caching 130 | │ ├── __init__.py # Python bindings 131 | │ └── CMakeLists.txt # Build configuration 132 | ├── src/models/llama/ 133 | │ ├── modelling_llama_skip.py # Statistical sparsity model 134 | │ └── configuration_llama_skip.py # Model configuration 135 | ├── tools/ 136 | │ └── component_timing.py # Performance profiling 137 | └── run_benchmark.py # End-to-end benchmarks 138 | ``` 139 | 140 | ## Installation 141 | 142 | 143 | ### Build C++ Extensions 144 | ```bash 145 | # Clone repository 146 | git clone https://github.com/nimbleedge/sparse_transformers.git 147 | cd sparse_transformers 148 | 149 | # Install in editable mode (builds C++ extensions automatically) 150 | pip install -r requirements.txt 151 | pip install -e . 152 | 153 | # Verify installation 154 | python -c "import sparse_transformers; print('✅ Installation successful!')" 155 | ``` 156 | 157 | ## Contributing 158 | We welcome contributions from the community! Areas of particular interest are: 159 | - **Additional models**: Extend beyond LLaMA to other architectures 160 | - **Quantization**: Combine with INT8/FP16 optimizations 161 | - **Attention Kernels**: Implement Sparse Attention Kernels 162 | 163 | Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started. 164 | 165 | ## License 166 | 167 | This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. 168 | 169 | -------------------------------------------------------------------------------- /benchmarks/llama1b/benchmark_cpu_l1.log: -------------------------------------------------------------------------------- 1 | 2 | Configuring for 8 CPU threads 3 | 4 | System Configuration: 5 | -------------------------------------------------- 6 | OS: Linux 5.15.0-1089-azure 7 | CPU: x86_64 8 | Physical cores: 8 9 | Total cores: 8 10 | Max CPU frequency: 0MHz 11 | Current CPU frequency: 2620MHz 12 | RAM: Total=54.92GB, Available=50.03GB (8.9% used) 13 | 14 | PyTorch version: 2.5.1 15 | CUDA version: 12.4 16 | -------------------------------------------------- 17 | Using devices: cpu, cpu, cpu 18 | 19 | 🎯 Running comprehensive benchmark with 5 diverse prompts... 
20 | 📝 Test prompts: ['Short simple prompt', 'Medium recipe prompt', 'Long technical explanation', 'Creative writing prompt', 'Complex analy 21 | tical prompt'] 22 | 23 | === Benchmarking SkipLLaMA === 24 | Model device: cpu 25 | Model dtype: torch.float32 26 | Warming up model... 27 | 28 | Running comprehensive benchmark on 5 prompts... 29 | 30 | Prompt 1/5: Short simple prompt 31 | Max tokens: 50 32 | TTFT: 0.348s 33 | Output TPS: 2.1 34 | Total TPS: 2.4 35 | 36 | Prompt 2/5: Medium recipe prompt 37 | Max tokens: 200 38 | TTFT: 0.378s 39 | Output TPS: 1.0 40 | Total TPS: 1.1 41 | 42 | Prompt 3/5: Long technical explanation 43 | Max tokens: 300 44 | TTFT: 0.444s 45 | Output TPS: 0.8 46 | Total TPS: 0.8 47 | 48 | Prompt 4/5: Creative writing prompt 49 | Max tokens: 400 50 | TTFT: 0.372s 51 | Output TPS: 0.7 52 | Total TPS: 0.7 53 | 54 | Prompt 5/5: Complex analytical prompt 55 | Max tokens: 500 56 | TTFT: 0.438s 57 | Output TPS: 0.5 58 | Total TPS: 0.6 59 | 60 | Running comprehensive benchmark on 5 prompts... 61 | 62 | Prompt 1/5: Short simple prompt 63 | Max tokens: 50 64 | TTFT: 0.765s 65 | Output TPS: 1.3 66 | Total TPS: 1.4 67 | 68 | Prompt 2/5: Medium recipe prompt 69 | Max tokens: 200 70 | TTFT: 0.620s 71 | Output TPS: 0.8 72 | Total TPS: 0.8 73 | 74 | Prompt 3/5: Long technical explanation 75 | Max tokens: 300 76 | TTFT: 0.660s 77 | Output TPS: 0.6 78 | Total TPS: 0.6 79 | 80 | Prompt 4/5: Creative writing prompt 81 | Max tokens: 400 82 | TTFT: 0.768s 83 | Output TPS: 0.5 84 | Total TPS: 0.5 85 | 86 | Prompt 5/5: Complex analytical prompt 87 | Max tokens: 500 88 | TTFT: 0.902s 89 | Output TPS: 0.4 90 | Total TPS: 0.4 91 | 92 | ============================================================ 93 | 📊 SkipLLaMA Benchmark Results 94 | ============================================================ 95 | 📈 Performance Metrics (n=5 prompts): 96 | ---------------------------------------- 97 | ⚡ Time to First Token: 98 | P50: 0.378s 99 | P90: 0.442s 100 | Mean: 0.396s 101 | 🚀 Output Generation Speed: 102 | P50: 0.8 tokens/sec 103 | P90: 1.7 tokens/sec 104 | Mean: 1.0 tokens/sec 105 | 📊 Total Throughput: 106 | P50: 0.8 tokens/sec 107 | P90: 1.9 tokens/sec 108 | Mean: 1.1 tokens/sec 109 | 110 | ============================================================ 111 | 📊 Standard LLaMA (HuggingFace) Benchmark Results 112 | ============================================================ 113 | 📈 Performance Metrics (n=5 prompts): 114 | ---------------------------------------- 115 | ⚡ Time to First Token: 116 | P50: 0.765s 117 | P90: 0.848s 118 | Mean: 0.743s 119 | 🚀 Output Generation Speed: 120 | P50: 0.6 tokens/sec 121 | P90: 1.1 tokens/sec 122 | Mean: 0.7 tokens/sec 123 | 📊 Total Throughput: 124 | P50: 0.6 tokens/sec 125 | P90: 1.2 tokens/sec 126 | Mean: 0.8 tokens/sec 127 | 128 | ============================================================ 129 | 🏁 Performance Comparison 130 | ============================================================ 131 | ⚡ TTFT Speedup: 1.88x 132 | 🚀 Output TPS Speedup: 1.43x 133 | 📊 Total Throughput Speedup: 1.44x -------------------------------------------------------------------------------- /benchmarks/llama1b/benchmark_gpu.log: -------------------------------------------------------------------------------- 1 | Device set to use cuda 2 | Device set to use cuda 3 | Configuring for 8 CPU threads 4 | 5 | System Configuration: 6 | -------------------------------------------------- 7 | OS: Linux 5.15.0-1079-azure 8 | CPU: x86_64 9 | Physical cores: 8 10 | Total cores: 8 11 | Max CPU 
frequency: 0MHz 12 | Current CPU frequency: 2544MHz 13 | RAM: Total=54.92GB, Available=51.71GB (5.8% used) 14 | 15 | GPU Configuration: 16 | -------------------------------------------------- 17 | 18 | GPU 0: Tesla T4 19 | Compute capability: 7.5 20 | Total memory: 15.56GB 21 | Free memory: 15.37GB 22 | Multi processors: 40 23 | 24 | PyTorch version: 2.5.1 25 | CUDA version: 12.4 26 | -------------------------------------------------- 27 | Number of available GPUs: 1 28 | Using devices: cuda, cuda, cuda 29 | 30 | Running CUDA inference benchmarks... 31 | -------------------------------------------------- 32 | Warming up models... 33 | 34 | Model type: 35 | Model device: cuda 36 | Model path: meta-llama/Llama-3.2-1B-Instruct 37 | 38 | Model type: 39 | Model device: cuda 40 | Model path: meta-llama/Llama-3.2-1B-Instruct 41 | 42 | Model type: 43 | Model device: cuda 44 | Model path: meta-llama/Llama-3.2-1B-Instruct 45 | 46 | Model type: 47 | Model device: cuda 48 | Model path: meta-llama/Llama-3.2-1B-Instruct 49 | 50 | SkipLLaMA Scripted CUDA Results: 51 | Average time: 0.021s 52 | Min time: 0.020s 53 | Max time: 0.021s 54 | Individual times: ['0.021s', '0.021s', '0.021s', '0.021s', '0.021s', '0.020s', '0.021s', '0.021s', '0.020s', '0.021s'] 55 | 56 | Standard LLaMA CUDA Results: 57 | Average time: 0.018s 58 | Min time: 0.018s 59 | Max time: 0.018s 60 | Individual times: ['0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s', '0.018s'] 61 | 62 | CUDA Speedups: 63 | Scripted vs Standard: 0.86x 64 | -------------------------------------------------------------------------------- /benchmarks/llama1b/summary.json: -------------------------------------------------------------------------------- 1 | { 2 | "total_sequences_analyzed": 36592, 3 | "total_layers": 16, 4 | "sparsity_thresholds": [ 5 | 0.1, 6 | 0.2, 7 | 0.5, 8 | 0.8, 9 | 0.9, 10 | 0.95, 11 | 0.99 12 | ], 13 | "context_window": 32, 14 | "consistency_metrics": { 15 | "layer_0_heavy_hitter_concentration": 0.28764301444491347, 16 | "layer_1_heavy_hitter_concentration": 0.26758001984588087, 17 | "layer_2_heavy_hitter_concentration": 0.2492908233089502, 18 | "layer_3_heavy_hitter_concentration": 0.23655857881149206, 19 | "layer_4_heavy_hitter_concentration": 0.24199748843684968, 20 | "layer_5_heavy_hitter_concentration": 0.27164510548196347, 21 | "layer_6_heavy_hitter_concentration": 0.28624969237830694, 22 | "layer_7_heavy_hitter_concentration": 0.3051159266092369, 23 | "layer_8_heavy_hitter_concentration": 0.29278576408485774, 24 | "layer_9_heavy_hitter_concentration": 0.29929179821281676, 25 | "layer_10_heavy_hitter_concentration": 0.27972099832263836, 26 | "layer_11_heavy_hitter_concentration": 0.28360533151210754, 27 | "layer_12_heavy_hitter_concentration": 0.3195107774305012, 28 | "layer_13_heavy_hitter_concentration": 0.31951920903150943, 29 | "layer_14_heavy_hitter_concentration": 0.40493080026960043, 30 | "layer_15_heavy_hitter_concentration": 0.5334573832631666, 31 | "mean_context_overlap": 0.05262602819239267, 32 | "std_context_overlap": 0.005512978210303243 33 | } 34 | } -------------------------------------------------------------------------------- /benchmarks/llama3b/benchmark_cpu_l3.log: -------------------------------------------------------------------------------- 1 | Configuration: configs/llama_skip_causal_3b.json 2 | 3 | Configuring for 8 CPU threads 4 | 5 | System Configuration: 6 | -------------------------------------------------- 7 | OS: Linux 5.15.0-1089-azure 8 | CPU: x86_64 9 | 
Physical cores: 8 10 | Total cores: 8 11 | Max CPU frequency: 0MHz 12 | Current CPU frequency: 2546MHz 13 | RAM: Total=54.92GB, Available=45.06GB (18.0% used) 14 | 15 | PyTorch version: 2.5.1 16 | CUDA version: 12.4 17 | -------------------------------------------------- 18 | Using devices: cpu, cpu, cpu 19 | 20 | Loading checkpoint shards: 0%| | 0/2 [00:00 [activations] 49 | self.token_sparsity = defaultdict(list) # Layer -> [sparsity_ratios] 50 | self.heavy_hitters = defaultdict(Counter) # Layer -> {neuron_id: frequency} 51 | self.contextual_patterns = defaultdict(list) # Context -> [active_neurons] 52 | self.sequence_patterns = [] # [(context_tokens, active_neurons)] 53 | 54 | # Analysis parameters 55 | self.sparsity_thresholds = [0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 0.99] 56 | self.context_window = config.get('context_window', 32) 57 | self.max_sequences = config.get('max_sequences', 1000) 58 | 59 | # Hook handles for activation collection 60 | self.hook_handles = [] 61 | self.current_activations = {} 62 | 63 | def register_hooks(self): 64 | """Register forward hooks to collect MLP activations.""" 65 | def create_hook(layer_name): 66 | def hook_fn(module, input, output): 67 | # Store intermediate activations (after gate * up, before down projection) 68 | if hasattr(module, 'gate_proj') and hasattr(module, 'up_proj'): 69 | gate_out = module.act_fn(module.gate_proj(input[0])) 70 | up_out = module.up_proj(input[0]) 71 | intermediate = gate_out * up_out # Element-wise multiplication 72 | self.current_activations[layer_name] = intermediate.detach() 73 | 74 | return hook_fn 75 | 76 | # Register hooks for all MLP layers 77 | for i, layer in enumerate(self.model.model.layers): 78 | layer_name = f"layer_{i}" 79 | handle = layer.mlp.register_forward_hook(create_hook(layer_name)) 80 | self.hook_handles.append(handle) 81 | 82 | def remove_hooks(self): 83 | """Remove all registered hooks.""" 84 | for handle in self.hook_handles: 85 | handle.remove() 86 | self.hook_handles.clear() 87 | 88 | def compute_token_sparsity(self, activations: torch.Tensor, thresholds: List[float]) -> Dict[float, float]: 89 | """Compute sparsity ratios at different thresholds for a single token.""" 90 | # activations: [batch_size, seq_len, intermediate_size] 91 | abs_activations = torch.abs(activations) 92 | 93 | # Convert to float32 for quantile computation (quantile requires float32/float64) 94 | if abs_activations.dtype == torch.float16: 95 | abs_activations = abs_activations.float() 96 | 97 | sparsity_ratios = {} 98 | for threshold in thresholds: 99 | # Compute threshold value (percentile of activation magnitudes) 100 | threshold_value = torch.quantile(abs_activations, threshold, dim=-1, keepdim=True) 101 | active_mask = abs_activations > threshold_value 102 | sparsity_ratio = active_mask.float().mean().item() 103 | sparsity_ratios[threshold] = 1.0 - sparsity_ratio # Convert to sparsity 104 | 105 | return sparsity_ratios 106 | 107 | def get_top_k_neurons(self, activations: torch.Tensor, k: int) -> Set[int]: 108 | """Get top-k most active neurons for a token.""" 109 | # activations: [intermediate_size] 110 | abs_activations = torch.abs(activations) 111 | 112 | # Ensure we can compute topk (usually works with float16, but ensure compatibility) 113 | if abs_activations.dtype == torch.float16 and not abs_activations.is_cuda: 114 | abs_activations = abs_activations.float() 115 | 116 | _, top_indices = torch.topk(abs_activations, k) 117 | return set(top_indices.cpu().numpy()) 118 | 119 | def analyze_contextual_similarity(self, 
context_tokens: List[int], active_neurons: Set[int]): 120 | """ 121 | Analyze similarity between current context and previous contexts. 122 | 123 | Note: Uses FULL context (all preceding tokens) instead of a sliding window 124 | to capture complete contextual patterns and long-range dependencies. 125 | This provides more accurate contextual sparsity analysis at the cost of 126 | more memory usage and potentially fewer exact context matches. 127 | """ 128 | # Use full context instead of limiting to context_window 129 | # This captures complete contextual information for better sparsity analysis 130 | context_key = tuple(context_tokens) # Use all available tokens as context 131 | self.contextual_patterns[context_key].append(active_neurons) 132 | 133 | # Store for sequence-level analysis 134 | self.sequence_patterns.append((context_tokens.copy(), active_neurons.copy())) 135 | 136 | def compute_heavy_hitters(self, layer_name: str, active_neurons: Set[int]): 137 | """Track neurons that are frequently active (heavy hitters).""" 138 | for neuron_id in active_neurons: 139 | self.heavy_hitters[layer_name][neuron_id] += 1 140 | 141 | def process_sequence(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Dict: 142 | """Process a single sequence and collect sparsity measurements.""" 143 | batch_size, seq_len = input_ids.shape 144 | results = { 145 | 'sequence_length': seq_len, 146 | 'layer_sparsity': defaultdict(list), 147 | 'token_patterns': [], 148 | 'heavy_hitters_per_layer': defaultdict(set) 149 | } 150 | 151 | self.model.eval() 152 | with torch.no_grad(): 153 | # Process sequence token by token for autoregressive analysis 154 | context_tokens = [] 155 | 156 | for pos in range(1, seq_len): # Start from position 1 (predicting token 1 from token 0) 157 | # Current context (up to position pos-1) 158 | current_input = input_ids[:, :pos] 159 | current_mask = attention_mask[:, :pos] if attention_mask is not None else None 160 | 161 | # Clear previous activations 162 | self.current_activations.clear() 163 | 164 | # Forward pass to collect activations 165 | if input_ids.is_cuda: 166 | with torch.amp.autocast('cuda'): 167 | outputs = self.model( 168 | input_ids=current_input, 169 | attention_mask=current_mask, 170 | output_hidden_states=False, 171 | use_cache=False 172 | ) 173 | else: 174 | with torch.no_grad(): 175 | outputs = self.model( 176 | input_ids=current_input, 177 | attention_mask=current_mask, 178 | output_hidden_states=False, 179 | use_cache=False 180 | ) 181 | 182 | # Analyze activations for each layer 183 | current_token = input_ids[0, pos].item() 184 | context_tokens.append(current_token) 185 | 186 | token_analysis = { 187 | 'position': pos, 188 | 'token_id': current_token, 189 | 'layer_analysis': {} 190 | } 191 | 192 | for layer_name, activations in self.current_activations.items(): 193 | # Get activations for the last token position 194 | last_token_activations = activations[0, -1, :] # [intermediate_size] 195 | 196 | # Compute sparsity ratios 197 | sparsity_ratios = self.compute_token_sparsity( 198 | last_token_activations.unsqueeze(0).unsqueeze(0), 199 | self.sparsity_thresholds 200 | ) 201 | 202 | # Get top-k active neurons (using 10% as active threshold) 203 | k = max(1, int(0.1 * last_token_activations.shape[0])) 204 | active_neurons = self.get_top_k_neurons(last_token_activations, k) 205 | 206 | # Track heavy hitters 207 | self.compute_heavy_hitters(layer_name, active_neurons) 208 | 209 | # Store analysis 210 | token_analysis['layer_analysis'][layer_name] = { 211 | 
'sparsity_ratios': sparsity_ratios, 212 | 'active_neurons': list(active_neurons), 213 | 'activation_stats': { 214 | 'mean': last_token_activations.mean().item(), 215 | 'std': last_token_activations.std().item(), 216 | 'max': last_token_activations.max().item(), 217 | 'min': last_token_activations.min().item() 218 | } 219 | } 220 | 221 | # Store for layer-wise analysis 222 | results['layer_sparsity'][layer_name].append(sparsity_ratios) 223 | results['heavy_hitters_per_layer'][layer_name].update(active_neurons) 224 | 225 | # Analyze contextual patterns 226 | if len(context_tokens) >= 4: # Need some context 227 | for layer_name, activations in self.current_activations.items(): 228 | last_token_activations = activations[0, -1, :] 229 | k = max(1, int(0.1 * last_token_activations.shape[0])) 230 | active_neurons = self.get_top_k_neurons(last_token_activations, k) 231 | self.analyze_contextual_similarity(context_tokens, active_neurons) 232 | 233 | results['token_patterns'].append(token_analysis) 234 | 235 | return results 236 | 237 | def compute_contextual_consistency(self) -> Dict[str, float]: 238 | """Compute consistency metrics for contextual patterns.""" 239 | consistency_metrics = {} 240 | 241 | for layer_name in self.heavy_hitters.keys(): 242 | # Heavy hitter concentration 243 | total_activations = sum(self.heavy_hitters[layer_name].values()) 244 | if total_activations > 0: 245 | # Top 10% neurons concentration 246 | sorted_neurons = sorted(self.heavy_hitters[layer_name].items(), 247 | key=lambda x: x[1], reverse=True) 248 | top_10_percent = int(0.1 * len(sorted_neurons)) or 1 249 | top_concentration = sum(count for _, count in sorted_neurons[:top_10_percent]) 250 | consistency_metrics[f'{layer_name}_heavy_hitter_concentration'] = top_concentration / total_activations 251 | 252 | # Context similarity metrics 253 | context_overlap_scores = [] 254 | for context_key, neuron_lists in self.contextual_patterns.items(): 255 | if len(neuron_lists) > 1: 256 | # Compute pairwise Jaccard similarity 257 | similarities = [] 258 | for i in range(len(neuron_lists)): 259 | for j in range(i + 1, len(neuron_lists)): 260 | set1, set2 = neuron_lists[i], neuron_lists[j] 261 | if len(set1) > 0 and len(set2) > 0: 262 | intersection = len(set1.intersection(set2)) 263 | union = len(set1.union(set2)) 264 | similarities.append(intersection / union if union > 0 else 0.0) 265 | 266 | if similarities: 267 | context_overlap_scores.extend(similarities) 268 | 269 | if context_overlap_scores: 270 | consistency_metrics['mean_context_overlap'] = np.mean(context_overlap_scores) 271 | consistency_metrics['std_context_overlap'] = np.std(context_overlap_scores) 272 | 273 | return consistency_metrics 274 | 275 | def save_results(self, save_dir: str): 276 | """Save analysis results to files.""" 277 | os.makedirs(save_dir, exist_ok=True) 278 | 279 | # Save raw data 280 | results = { 281 | 'heavy_hitters': dict(self.heavy_hitters), 282 | 'contextual_patterns': {str(k): [list(s) for s in v] 283 | for k, v in self.contextual_patterns.items()}, 284 | 'sequence_patterns': [(tokens, list(neurons)) 285 | for tokens, neurons in self.sequence_patterns], 286 | 'consistency_metrics': self.compute_contextual_consistency() 287 | } 288 | 289 | with open(os.path.join(save_dir, 'sparsity_analysis.pkl'), 'wb') as f: 290 | pickle.dump(results, f) 291 | 292 | # Save summary statistics 293 | summary = { 294 | 'total_sequences_analyzed': len(self.sequence_patterns), 295 | 'total_layers': len(self.heavy_hitters), 296 | 'sparsity_thresholds': 
self.sparsity_thresholds, 297 | 'context_window': self.context_window, 298 | 'consistency_metrics': results['consistency_metrics'] 299 | } 300 | 301 | with open(os.path.join(save_dir, 'summary.json'), 'w') as f: 302 | json.dump(summary, f, indent=2) 303 | 304 | logger.info(f"Results saved to {save_dir}") 305 | 306 | def plot_sparsity_analysis(self, save_dir: str): 307 | """Generate visualization plots for sparsity analysis.""" 308 | os.makedirs(save_dir, exist_ok=True) 309 | 310 | # Plot 1: Sparsity ratios across layers and thresholds 311 | plt.figure(figsize=(12, 8)) 312 | 313 | layer_names = sorted(self.heavy_hitters.keys()) 314 | threshold_data = defaultdict(list) 315 | 316 | # Aggregate sparsity data across all processed sequences 317 | for layer_name in layer_names: 318 | layer_sparsity_data = [] 319 | # This would need to be collected during processing 320 | # For now, we'll use heavy hitter data as proxy 321 | total_neurons = 8192 # Typical intermediate size for LLaMA 322 | active_neurons = len(self.heavy_hitters[layer_name]) 323 | sparsity_proxy = 1.0 - (active_neurons / total_neurons) 324 | layer_sparsity_data.append(sparsity_proxy) 325 | 326 | # Plot heavy hitter distribution 327 | plt.figure(figsize=(15, 10)) 328 | 329 | # Subplot 1: Heavy hitter concentration per layer 330 | plt.subplot(2, 2, 1) 331 | layer_indices = range(len(layer_names)) 332 | heavy_hitter_counts = [len(self.heavy_hitters[layer]) for layer in layer_names] 333 | 334 | plt.bar(layer_indices, heavy_hitter_counts) 335 | plt.xlabel('Layer Index') 336 | plt.ylabel('Number of Heavy Hitter Neurons') 337 | plt.title('Heavy Hitter Distribution Across Layers') 338 | plt.xticks(layer_indices[::4], [f'L{i}' for i in layer_indices[::4]]) 339 | 340 | # Subplot 2: Top neurons frequency distribution 341 | plt.subplot(2, 2, 2) 342 | all_frequencies = [] 343 | for layer_name in layer_names[:5]: # Show first 5 layers 344 | frequencies = list(self.heavy_hitters[layer_name].values()) 345 | if frequencies: 346 | all_frequencies.extend(frequencies) 347 | 348 | if all_frequencies: 349 | plt.hist(all_frequencies, bins=50, alpha=0.7, edgecolor='black') 350 | plt.xlabel('Activation Frequency') 351 | plt.ylabel('Number of Neurons') 352 | plt.title('Neuron Activation Frequency Distribution') 353 | plt.yscale('log') 354 | 355 | # Subplot 3: Context similarity distribution 356 | plt.subplot(2, 2, 3) 357 | consistency_metrics = self.compute_contextual_consistency() 358 | if 'mean_context_overlap' in consistency_metrics: 359 | overlap_scores = [] 360 | for neuron_lists in self.contextual_patterns.values(): 361 | if len(neuron_lists) > 1: 362 | for i in range(len(neuron_lists)): 363 | for j in range(i + 1, len(neuron_lists)): 364 | set1, set2 = neuron_lists[i], neuron_lists[j] 365 | if len(set1) > 0 and len(set2) > 0: 366 | intersection = len(set1.intersection(set2)) 367 | union = len(set1.union(set2)) 368 | overlap_scores.append(intersection / union if union > 0 else 0.0) 369 | 370 | if overlap_scores: 371 | plt.hist(overlap_scores, bins=30, alpha=0.7, edgecolor='black') 372 | plt.xlabel('Jaccard Similarity') 373 | plt.ylabel('Frequency') 374 | plt.title('Contextual Pattern Similarity') 375 | 376 | # Subplot 4: Layer-wise heavy hitter concentration 377 | plt.subplot(2, 2, 4) 378 | concentrations = [] 379 | for layer_name in layer_names: 380 | total_activations = sum(self.heavy_hitters[layer_name].values()) 381 | if total_activations > 0: 382 | sorted_neurons = sorted(self.heavy_hitters[layer_name].items(), 383 | key=lambda x: x[1], 
reverse=True) 384 | top_10_percent = int(0.1 * len(sorted_neurons)) or 1 385 | top_concentration = sum(count for _, count in sorted_neurons[:top_10_percent]) 386 | concentrations.append(top_concentration / total_activations) 387 | else: 388 | concentrations.append(0.0) 389 | 390 | plt.plot(layer_indices, concentrations, 'o-') 391 | plt.xlabel('Layer Index') 392 | plt.ylabel('Top 10% Neuron Concentration') 393 | plt.title('Heavy Hitter Concentration by Layer') 394 | plt.xticks(layer_indices[::4], [f'L{i}' for i in layer_indices[::4]]) 395 | 396 | plt.tight_layout() 397 | plt.savefig(os.path.join(save_dir, 'sparsity_analysis.png'), dpi=300, bbox_inches='tight') 398 | plt.close() 399 | 400 | logger.info(f"Plots saved to {save_dir}") 401 | 402 | 403 | class C4Dataset(Dataset): 404 | """C4 dataset for contextual sparsity analysis.""" 405 | 406 | def __init__(self, tokenizer, max_length: int = 512, num_samples: int = 1000): 407 | self.tokenizer = tokenizer 408 | self.max_length = max_length 409 | 410 | # Load C4 dataset 411 | logger.info("Loading C4 dataset...") 412 | dataset = load_dataset("allenai/c4", "realnewslike", split="train", streaming=True) 413 | 414 | # Process samples 415 | self.samples = [] 416 | for i, sample in enumerate(dataset): 417 | if i >= num_samples: 418 | break 419 | 420 | text = sample['text'] 421 | if len(text.strip()) > 50: # Filter out very short texts 422 | encoding = tokenizer( 423 | text, 424 | truncation=True, 425 | padding=False, 426 | max_length=max_length, 427 | return_tensors='pt' 428 | ) 429 | 430 | if encoding['input_ids'].shape[1] > 10: # Ensure minimum sequence length 431 | self.samples.append({ 432 | 'input_ids': encoding['input_ids'].squeeze(), 433 | 'attention_mask': encoding['attention_mask'].squeeze(), 434 | 'text': text[:200] + "..." 
if len(text) > 200 else text 435 | }) 436 | 437 | logger.info(f"Loaded {len(self.samples)} C4 samples") 438 | 439 | def __len__(self): 440 | return len(self.samples) 441 | 442 | def __getitem__(self, idx): 443 | return self.samples[idx] 444 | 445 | 446 | def main(): 447 | parser = argparse.ArgumentParser(description="Measure contextual sparsity in LLaMA models") 448 | parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.2-3B-Instruct", 449 | help="HuggingFace model name or path") 450 | parser.add_argument("--output_dir", type=str, required=True, 451 | help="Output directory for results") 452 | parser.add_argument("--num_samples", type=int, default=1000, 453 | help="Number of C4 samples to analyze") 454 | parser.add_argument("--max_length", type=int, default=512, 455 | help="Maximum sequence length") 456 | parser.add_argument("--batch_size", type=int, default=1, 457 | help="Batch size (recommend 1 for token-by-token analysis)") 458 | parser.add_argument("--context_window", type=int, default=32, 459 | help="Context window size for pattern analysis (UNUSED: now using full context)") 460 | parser.add_argument("--device", type=str, default="auto", 461 | help="Device to use (auto, cpu, cuda)") 462 | parser.add_argument("--seed", type=int, default=42, 463 | help="Random seed") 464 | parser.add_argument("--save_plots", action="store_true", 465 | help="Generate and save analysis plots") 466 | 467 | args = parser.parse_args() 468 | 469 | # Set seed 470 | set_seed(args.seed) 471 | 472 | # Setup device 473 | if args.device == "auto": 474 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 475 | else: 476 | device = torch.device(args.device) 477 | 478 | logger.info(f"Using device: {device}") 479 | 480 | # Setup output directory 481 | os.makedirs(args.output_dir, exist_ok=True) 482 | 483 | # Load model and tokenizer 484 | logger.info(f"Loading model: {args.model_name}") 485 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) 486 | if tokenizer.pad_token is None: 487 | tokenizer.pad_token = tokenizer.eos_token 488 | 489 | model = AutoModelForCausalLM.from_pretrained( 490 | args.model_name, 491 | torch_dtype=torch.float16 if device.type == "cuda" else torch.float32, 492 | device_map="auto" if device.type == "cuda" else None, 493 | trust_remote_code=True 494 | ) 495 | 496 | if device.type != "cuda": 497 | model = model.to(device) 498 | 499 | # Load C4 dataset 500 | dataset = C4Dataset(tokenizer, args.max_length, args.num_samples) 501 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False) 502 | 503 | # Initialize sparsity analyzer 504 | config = { 505 | 'context_window': args.context_window, 506 | 'max_sequences': args.num_samples 507 | } 508 | analyzer = ContextualSparsityAnalyzer(model, tokenizer, config) 509 | 510 | # Register hooks for activation collection 511 | analyzer.register_hooks() 512 | 513 | try: 514 | # Process dataset 515 | logger.info("Starting contextual sparsity analysis...") 516 | 517 | for batch_idx, batch in enumerate(tqdm(dataloader, desc="Analyzing sequences")): 518 | input_ids = batch['input_ids'].to(device) 519 | attention_mask = batch['attention_mask'].to(device) 520 | 521 | # Process sequence 522 | results = analyzer.process_sequence(input_ids, attention_mask) 523 | 524 | # Log progress 525 | if (batch_idx + 1) % 100 == 0: 526 | logger.info(f"Processed {batch_idx + 1}/{len(dataloader)} sequences") 527 | 528 | # Intermediate consistency metrics 529 | consistency = 
analyzer.compute_contextual_consistency() 530 | if 'mean_context_overlap' in consistency: 531 | logger.info(f"Current mean context overlap: {consistency['mean_context_overlap']:.4f}") 532 | 533 | # Save results 534 | logger.info("Saving analysis results...") 535 | analyzer.save_results(args.output_dir) 536 | 537 | # Generate plots if requested 538 | if args.save_plots: 539 | logger.info("Generating visualization plots...") 540 | analyzer.plot_sparsity_analysis(args.output_dir) 541 | 542 | # Print final summary 543 | consistency_metrics = analyzer.compute_contextual_consistency() 544 | 545 | print(f"\n{'='*60}") 546 | print(f"🎯 CONTEXTUAL SPARSITY ANALYSIS SUMMARY") 547 | print(f"{'='*60}") 548 | print(f"📊 Model: {args.model_name}") 549 | print(f"📝 Sequences analyzed: {len(analyzer.sequence_patterns)}") 550 | print(f"🧠 Layers analyzed: {len(analyzer.heavy_hitters)}") 551 | print(f"🔍 Context patterns found: {len(analyzer.contextual_patterns)}") 552 | 553 | print(f"\n📈 Consistency Metrics:") 554 | for metric, value in consistency_metrics.items(): 555 | if isinstance(value, float): 556 | print(f" {metric}: {value:.4f}") 557 | 558 | # Heavy hitter summary 559 | print(f"\n🎯 Heavy Hitter Analysis:") 560 | for layer_name in sorted(list(analyzer.heavy_hitters.keys())[:5]): # Show first 5 layers 561 | heavy_count = len(analyzer.heavy_hitters[layer_name]) 562 | total_activations = sum(analyzer.heavy_hitters[layer_name].values()) 563 | print(f" {layer_name}: {heavy_count} heavy hitters, {total_activations} total activations") 564 | 565 | print(f"\n✅ Analysis completed! Results saved to: {args.output_dir}") 566 | print(f"📁 Files generated:") 567 | print(f" - sparsity_analysis.pkl (raw data)") 568 | print(f" - summary.json (summary statistics)") 569 | if args.save_plots: 570 | print(f" - sparsity_analysis.png (visualization)") 571 | 572 | finally: 573 | # Clean up hooks 574 | analyzer.remove_hooks() 575 | 576 | 577 | if __name__ == "__main__": 578 | main() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "wheel", "torch"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sparse_transformers" 7 | version = "0.0.1" 8 | description = "Sparse Inference for transformers" 9 | authors = [ 10 | {name = "NimbleEdge"} 11 | ] 12 | requires-python = ">=3.7" 13 | 14 | [tool.setuptools] 15 | packages = ["sparse_transformers"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | aiohappyeyeballs==2.4.2 3 | aiohttp==3.10.6 4 | aiosignal==1.3.1 5 | annotated-types==0.7.0 6 | anyio==4.5.0 7 | appdirs==1.4.3 8 | async-timeout==4.0.3 9 | attrs==25.1.0 10 | audioread==3.0.1 11 | Automat==0.8.0 12 | beautifulsoup4==4.12.3 13 | bitsandbytes==0.42.0 14 | bleach==6.1.0 15 | blinker==1.4 16 | certifi==2019.11.28 17 | cffi==1.17.1 18 | chardet==3.0.4 19 | charset-normalizer==3.3.2 20 | Click==7.0 21 | cloud-init==24.4 22 | cloudpickle==3.1.1 23 | colorama==0.4.3 24 | coloredlogs==15.0.1 25 | command-not-found==0.3 26 | compressed-tensors==0.6.0 27 | configobj==5.0.6 28 | constantly==15.1.0 29 | contourpy==1.1.1 30 | cryptography==2.8 31 | cupshelpers==1.0 32 | cycler==0.12.1 33 | Cython==3.0.11 34 | dataclasses-json==0.6.7 35 | datasets==3.0.1 36 | dbus-python==1.2.16 37 | 
decorator==4.4.2 38 | defer==1.0.6 39 | defusedxml==0.7.1 40 | dill==0.3.8 41 | diskcache==5.6.3 42 | Distance==0.1.3 43 | distlib==0.3.0 44 | distro==1.9.0 45 | distro-info==0.23+ubuntu1.1 46 | einops==0.8.0 47 | entrypoints==0.3 48 | exceptiongroup==1.2.2 49 | faiss-gpu==1.7.2 50 | fastapi==0.112.4 51 | fastjsonschema==2.20.0 52 | filelock==3.16.1 53 | flatbuffers==24.3.25 54 | fonttools==4.54.1 55 | frozenlist==1.4.1 56 | fsspec==2024.6.1 57 | g2p-en==2.1.0 58 | gguf==0.10.0 59 | greenlet==3.1.1 60 | h11==0.14.0 61 | httpcore==1.0.5 62 | httplib2==0.14.0 63 | httptools==0.6.4 64 | httpx==0.27.2 65 | huggingface-hub==0.25.1 66 | humanfriendly==10.0 67 | hyperlink==19.0.0 68 | idna==2.8 69 | importlib_metadata==8.5.0 70 | importlib_resources==6.4.5 71 | incremental==16.10.1 72 | inflect==7.4.0 73 | interegular==0.3.3 74 | ipython_genutils==0.2.0 75 | jaro-winkler==2.0.3 76 | Jinja2==3.1.4 77 | jiter==0.8.2 78 | joblib==1.4.2 79 | jsonpatch==1.33 80 | jsonpointer==2.0 81 | jsonschema==4.23.0 82 | jsonschema-specifications==2023.12.1 83 | jupyter_client==8.6.3 84 | jupyter_core==5.7.2 85 | jupyterlab_pygments==0.3.0 86 | keyring==18.0.1 87 | kiwisolver==1.4.7 88 | langchain==0.2.16 89 | langchain-community==0.2.17 90 | langchain-core==0.2.41 91 | langchain-text-splitters==0.2.4 92 | langsmith==0.1.129 93 | language-selector==0.1 94 | lark==1.2.2 95 | launchpadlib==1.10.13 96 | lazr.restfulclient==0.14.2 97 | lazr.uri==1.0.3 98 | lazy_loader==0.4 99 | librosa==0.10.2.post1 100 | lightning==2.3.3 101 | lightning-utilities==0.11.9 102 | llvmlite==0.41.1 103 | lm-format-enforcer==0.10.6 104 | macaroonbakery==1.3.1 105 | MarkupSafe==2.1.5 106 | marshmallow==3.22.0 107 | matplotlib==3.7.5 108 | mistral_common==1.5.2 109 | mistune==3.0.2 110 | more-itertools==10.5.0 111 | mpmath==1.3.0 112 | msgpack==1.1.0 113 | msgspec==0.18.6 114 | multidict==6.1.0 115 | multiprocess==0.70.16 116 | mypy-extensions==1.0.0 117 | nbclient==0.10.0 118 | nbconvert==7.16.4 119 | nbformat==5.10.4 120 | nest-asyncio==1.6.0 121 | netifaces==0.10.4 122 | networkx==3.1 123 | nltk==3.9.1 124 | numba==0.58.1 125 | numpy==1.23.5 126 | nvidia-cublas-cu12==12.1.3.1 127 | nvidia-cuda-cupti-cu12==12.1.105 128 | nvidia-cuda-nvrtc-cu12==12.1.105 129 | nvidia-cuda-runtime-cu12==12.1.105 130 | nvidia-cudnn-cu12==9.1.0.70 131 | nvidia-cufft-cu12==11.0.2.54 132 | nvidia-curand-cu12==10.3.2.106 133 | nvidia-cusolver-cu12==11.4.5.107 134 | nvidia-cusparse-cu12==12.1.0.106 135 | nvidia-ml-py==12.570.86 136 | nvidia-nccl-cu12==2.20.5 137 | nvidia-nvjitlink-cu12==12.6.68 138 | nvidia-nvtx-cu12==12.1.105 139 | oauthlib==3.1.0 140 | onnx==1.16.2 141 | onnxruntime==1.16.3 142 | openai==1.61.1 143 | opencv-python-headless==4.11.0.86 144 | optimum==1.22.0 145 | orjson==3.10.7 146 | outlines==0.0.46 147 | packaging==24.1 148 | pandas==2.0.3 149 | pandocfilters==1.5.1 150 | partial-json-parser==0.2.1.1.post5 151 | pexpect==4.6.0 152 | pillow==10.4.0 153 | pipenv==11.9.0 154 | pkgutil_resolve_name==1.3.10 155 | platformdirs==4.3.6 156 | pooch==1.8.2 157 | prometheus-fastapi-instrumentator==7.0.2 158 | prometheus_client==0.21.1 159 | protobuf==3.20.2 160 | psutil==6.0.0 161 | py-cpuinfo==9.0.0 162 | pyairports==2.1.1 163 | pyarrow==17.0.0 164 | pyasn1==0.4.2 165 | pyasn1-modules==0.2.1 166 | pycairo==1.16.2 167 | pycountry==24.6.1 168 | pycparser==2.22 169 | pycups==1.9.73 170 | pydantic==2.9.2 171 | pydantic_core==2.23.4 172 | Pygments==2.18.0 173 | PyGObject==3.36.0 174 | PyHamcrest==1.9.0 175 | PyJWT==1.7.1 176 | pymacaroons==0.13.0 177 | 
PyNaCl==1.3.0 178 | pyOpenSSL==19.0.0 179 | pyparsing==3.1.4 180 | pyparted==3.11.2 181 | pyRFC3339==1.1 182 | pyrsistent==0.15.5 183 | pyserial==3.4 184 | python-apt==2.0.1+ubuntu0.20.4.1 185 | python-dateutil==2.9.0.post0 186 | python-debian==0.1.36+ubuntu1.1 187 | python-dotenv==1.0.1 188 | pytorch-lightning==2.4.0 189 | pytz==2024.2 190 | pyworld==0.3.4 191 | PyYAML==6.0.2 192 | pyzmq==26.2.0 193 | ray==2.10.0 194 | referencing==0.35.1 195 | regex==2024.9.11 196 | requests==2.32.3 197 | requests-unixsocket==0.2.0 198 | rpds-py==0.20.1 199 | safetensors==0.4.5 200 | scikit-learn==1.3.2 201 | scipy==1.10.1 202 | screen-resolution-extra==0.0.0 203 | seaborn==0.13.2 204 | SecretStorage==2.3.1 205 | sentence-transformers==3.1.1 206 | sentencepiece==0.2.0 207 | service-identity==18.1.0 208 | simplejson==3.16.0 209 | six==1.14.0 210 | sniffio==1.3.1 211 | sos==4.5.6 212 | sounddevice==0.5.1 213 | soundfile==0.12.1 214 | soupsieve==2.6 215 | soxr==0.3.7 216 | SQLAlchemy==2.0.35 217 | ssh-import-id==5.10 218 | starlette==0.38.6 219 | sympy==1.13.3 220 | systemd-python==234 221 | tenacity==8.5.0 222 | tgt==1.5 223 | threadpoolctl==3.5.0 224 | tiktoken==0.7.0 225 | timm==1.0.9 226 | tinycss2==1.3.0 227 | tokenizers==0.20.3 228 | torch==2.4.0 229 | torchmetrics==1.5.2 230 | torchvision==0.19.0 231 | tornado==6.4.1 232 | tqdm==4.66.5 233 | traitlets==5.14.3 234 | transformers==4.46.3 235 | triton==3.0.0 236 | Twisted==18.9.0 237 | typeguard==4.4.0 238 | typing-inspect==0.9.0 239 | typing_extensions==4.12.2 240 | tzdata==2024.2 241 | ubuntu-advantage-tools==8001 242 | ubuntu-drivers-common==0.0.0 243 | ufw==0.36 244 | unattended-upgrades==0.1 245 | Unidecode==1.3.8 246 | urllib3==1.25.8 247 | uvicorn==0.33.0 248 | uvloop==0.21.0 249 | validators==0.34.0 250 | virtualenv==20.0.17 251 | virtualenv-clone==0.3.0 252 | vllm==0.6.3.post1 253 | wadllib==1.3.3 254 | WALinuxAgent==2.2.46 255 | watchfiles==0.24.0 256 | webencodings==0.5.1 257 | websockets==13.1 258 | xformers==0.0.27.post2 259 | xkit==0.0.0 260 | xxhash==3.5.0 261 | yarl==1.13.0 262 | zipp==3.20.2 263 | zope.interface==4.7.1 264 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 3 | import os 4 | import torch 5 | from pathlib import Path 6 | import shutil 7 | import sys 8 | 9 | # Create build directory if it doesn't exist 10 | build_dir = Path(__file__).parent / 'build' 11 | if build_dir.exists(): 12 | shutil.rmtree(build_dir) 13 | build_dir.mkdir(parents=True) 14 | (build_dir / 'lib').mkdir(exist_ok=True) 15 | 16 | # Set environment variables to control build output 17 | os.environ['TORCH_BUILD_DIR'] = str(build_dir) 18 | os.environ['BUILD_LIB'] = str(build_dir / 'lib') 19 | os.environ['BUILD_TEMP'] = str(build_dir / 'temp') 20 | 21 | # Get CUDA compute capability if GPU is available 22 | arch_flags = [] 23 | if torch.cuda.is_available(): 24 | arch_list = [] 25 | for i in range(torch.cuda.device_count()): 26 | arch_list.append(torch.cuda.get_device_capability(i)) 27 | arch_list = sorted(list(set(arch_list))) 28 | arch_flags = [f"-gencode=arch=compute_{arch[0]}{arch[1]},code=sm_{arch[0]}{arch[1]}" for arch in arch_list] 29 | 30 | # Common optimization flags 31 | common_compile_args = [ 32 | '-O3', # Maximum optimization 33 | '-march=native', # Optimize for local CPU architecture 34 | 
'-ffast-math', # Aggressive floating point optimizations 35 | '-fopenmp', # OpenMP support 36 | '-flto', # Link-time optimization 37 | '-funroll-loops', # Unroll loops 38 | '-fno-math-errno', # Assume math functions never set errno 39 | '-fno-trapping-math', # Assume FP ops don't generate traps 40 | '-mtune=native', # Tune code for local CPU 41 | ] 42 | 43 | # CPU-specific optimization flags 44 | cpu_compile_args = common_compile_args + [ 45 | '-mavx2', # Enable AVX2 instructions if available 46 | '-mfma', # Enable FMA instructions if available 47 | '-fno-plt', # Improve indirect call performance 48 | '-fuse-linker-plugin', # Enable LTO plugin 49 | '-fomit-frame-pointer', # Remove frame pointers 50 | '-fno-stack-protector', # Disable stack protector 51 | '-fvisibility=hidden', # Hide all symbols by default 52 | '-fdata-sections', # Place each data item into its own section 53 | '-ffunction-sections', # Place each function into its own section 54 | ] 55 | 56 | # CUDA-specific optimization flags 57 | cuda_compile_args = ['-O3', '--use_fast_math'] + arch_flags + [ 58 | '--compiler-options', "'-fPIC'", 59 | '--compiler-options', "'-O3'", 60 | '--compiler-options', "'-march=native'", 61 | '--compiler-options', "'-ffast-math'", 62 | '-std=c++17' # Force C++17 instead of C++20 63 | ] 64 | 65 | # Link flags 66 | extra_link_args = [ 67 | '-fopenmp', 68 | '-flto', # Link-time optimization 69 | '-fuse-linker-plugin', # Enable LTO plugin 70 | '-Wl,--as-needed', # Only link needed libraries 71 | '-Wl,-O3', # Linker optimizations 72 | '-Wl,--strip-all', # Strip all symbols 73 | '-Wl,--gc-sections', # Remove unused sections 74 | '-Wl,--exclude-libs,ALL', # Don't export any symbols from libraries 75 | ] 76 | 77 | # Get CUDA include paths 78 | def get_cuda_include_dirs(): 79 | cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda') 80 | if not os.path.exists(cuda_home): 81 | cuda_home = os.getenv('CUDA_PATH') # Windows 82 | 83 | if cuda_home is None: 84 | raise RuntimeError("CUDA_HOME or CUDA_PATH environment variable is not set") 85 | 86 | return [ 87 | os.path.join(cuda_home, 'include'), 88 | os.path.join(cuda_home, 'samples', 'common', 'inc') 89 | ] 90 | 91 | # Base extension configuration 92 | base_include_dirs = [ 93 | os.path.dirname(torch.__file__) + '/include', 94 | os.path.dirname(torch.__file__) + '/include/torch/csrc/api/include', 95 | os.path.dirname(torch.__file__) + '/include/ATen', 96 | os.path.dirname(torch.__file__) + '/include/c10', 97 | ] 98 | 99 | if torch.cuda.is_available(): 100 | base_include_dirs.extend(get_cuda_include_dirs()) 101 | 102 | # Define extensions 103 | ext_modules = [] 104 | if torch.cuda.is_available(): 105 | extension = CUDAExtension( 106 | name='sparse_transformers', 107 | sources=[ 108 | 'sparse_transformers/csrc/sparse_mlp_op.cpp', 109 | 'sparse_transformers/csrc/sparse_mlp_cuda.cu' 110 | ], 111 | include_dirs=base_include_dirs, 112 | extra_compile_args={ 113 | 'cxx': cpu_compile_args, 114 | 'nvcc': cuda_compile_args 115 | }, 116 | extra_link_args=extra_link_args, 117 | libraries=['gomp', 'cudart'], 118 | library_dirs=[str(build_dir / 'lib')], 119 | define_macros=[('WITH_CUDA', None)] 120 | ) 121 | else: 122 | extension = CppExtension( 123 | name='sparse_transformers', 124 | sources=['sparse_transformers/csrc/sparse_mlp_op.cpp'], 125 | extra_compile_args=cpu_compile_args, 126 | extra_link_args=extra_link_args, 127 | library_dirs=[str(build_dir / 'lib')], 128 | include_dirs=base_include_dirs, 129 | libraries=['gomp'] 130 | ) 131 | 132 | ext_modules.append(extension) 
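# Build note (illustrative sketch; assumes a machine whose only GPU reports
# compute capability (8, 6)): in that case the arch_flags assembled above reduce to
#     ["-gencode=arch=compute_86,code=sm_86"]
# and the extension is typically compiled with `pip install -e .` or
# `python setup.py build_ext`; the CustomBuildExtension below forces the resulting
# sparse_transformers.so into build/lib/, which is where
# sparse_transformers/__init__.py loads it from.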
133 | 134 | # Custom build extension to handle clean builds 135 | class CustomBuildExtension(BuildExtension): 136 | def get_ext_filename(self, ext_name): 137 | # Force output to build directory 138 | filename = super().get_ext_filename(ext_name) 139 | return str(build_dir / 'lib' / os.path.basename(filename)) 140 | 141 | def get_ext_fullpath(self, ext_name): 142 | # Override to ensure extension is built in our build directory 143 | filename = self.get_ext_filename(ext_name) 144 | return str(build_dir / 'lib' / filename) 145 | 146 | def build_extensions(self): 147 | # Clean old build files 148 | if self.parallel: 149 | self.parallel = False # Disable parallel build for CUDA 150 | super().build_extensions() 151 | 152 | setup( 153 | name='sparse_transformers', 154 | packages=find_packages(), 155 | ext_modules=ext_modules, 156 | cmdclass={ 157 | 'build_ext': CustomBuildExtension.with_options(no_python_abi_suffix=True), 158 | } 159 | ) -------------------------------------------------------------------------------- /sparse_transformers/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1 FATAL_ERROR) 2 | project(sparse_mlp) 3 | 4 | find_package(Torch REQUIRED) 5 | find_package(OpenMP REQUIRED) 6 | 7 | # Define our library target 8 | add_library(sparse_mlp SHARED 9 | csrc/sparse_mlp_op.cpp 10 | ) 11 | 12 | # Enable C++17 13 | target_compile_features(sparse_mlp PRIVATE cxx_std_17) 14 | 15 | # Add OpenMP flags 16 | target_compile_options(sparse_mlp PRIVATE ${OpenMP_CXX_FLAGS}) 17 | 18 | # Include directories 19 | target_include_directories(sparse_mlp PRIVATE 20 | ${TORCH_INCLUDE_DIRS} 21 | ${CMAKE_CURRENT_SOURCE_DIR}/csrc 22 | ) 23 | 24 | # Link against LibTorch and OpenMP 25 | target_link_libraries(sparse_mlp PRIVATE 26 | ${TORCH_LIBRARIES} 27 | OpenMP::OpenMP_CXX 28 | ) 29 | 30 | # Set output directory 31 | set_target_properties(sparse_mlp PROPERTIES 32 | LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/lib" 33 | RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/bin" 34 | ) 35 | 36 | # Add optimization flags 37 | target_compile_options(sparse_mlp PRIVATE 38 | -O3 39 | -ffast-math 40 | -march=native 41 | ) -------------------------------------------------------------------------------- /sparse_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | # Configure CPU threads 4 | num_threads = os.cpu_count() 5 | print(f"Configuring for {num_threads} CPU threads") 6 | os.environ['OMP_NUM_THREADS'] = str(num_threads) 7 | os.environ['MKL_NUM_THREADS'] = str(num_threads) 8 | os.environ['OPENBLAS_NUM_THREADS'] = str(num_threads) 9 | os.environ['VECLIB_MAXIMUM_THREADS'] = str(num_threads) 10 | os.environ['NUMEXPR_NUM_THREADS'] = str(num_threads) 11 | torch.set_num_threads(num_threads) 12 | torch.set_num_interop_threads(num_threads) 13 | os.environ['MAX_JOBS'] = str(num_threads) 14 | 15 | torch.classes.load_library("./build/lib/sparse_transformers.so") 16 | 17 | sparse_mlp_forward = torch.ops.sparse_mlp.forward 18 | WeightCache = torch.classes.sparse_mlp.WeightCache 19 | 20 | # Export Count-Min Sketch approximate top-k function from sparse_mlp library 21 | approx_topk_threshold = torch.ops.sparse_mlp.approx_topk_threshold 22 | 23 | # Re-export for API compatibility 24 | __all__ = [ 25 | 'sparse_mlp_forward', 26 | 'WeightCache', 27 | 'approx_topk_threshold' 28 | ] 
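# Usage sketch (illustrative only): the exports above can be exercised directly from
# Python once the extension has been built into ./build/lib/ and the interpreter is
# started from the repository root. The shapes, the sparsity mask, and the activation
# string below are assumptions made for the example, not values taken from the library:
#
#   import torch
#   from sparse_transformers import sparse_mlp_forward, WeightCache
#
#   hidden, inter, batch = 2048, 8192, 1
#   gate_w = torch.randn(inter, hidden)          # gate_proj weight  [inter, hidden]
#   up_w   = torch.randn(inter, hidden)          # up_proj weight    [inter, hidden]
#   down_w = torch.randn(hidden, inter)          # down_proj weight  [hidden, inter]
#   mask = torch.zeros(inter, dtype=torch.bool); mask[:1024] = True  # active neurons
#
#   cache = WeightCache(mask, hidden, gate_w, up_w, down_w)
#   x = torch.randn(batch, hidden)
#   down_buf = torch.zeros(batch, hidden)
#   comb_buf = torch.zeros(batch, cache.get_concat_weight().size(0))
#   y = sparse_mlp_forward(x, cache.get_concat_weight(),
#                          cache.get_active_down_weight(),
#                          down_buf, comb_buf, "silu")   # y: [batch, hidden]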
-------------------------------------------------------------------------------- /sparse_transformers/csrc/approx_topk.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | // Count-Min Sketch inspired method - O(n) time complexity with parallel batch processing 10 | torch::Tensor approx_topk_threshold( 11 | const torch::Tensor &scores, 12 | int64_t k) 13 | { 14 | TORCH_CHECK(scores.dim() == 2, "Input scores must be 2D tensor [batch_size, features]"); 15 | TORCH_CHECK(k > 0, "k must be positive"); 16 | 17 | auto batch_size = scores.size(0); 18 | auto feature_size = scores.size(1); 19 | 20 | TORCH_CHECK(k <= feature_size, "k cannot be larger than feature size"); 21 | 22 | auto options = torch::TensorOptions().dtype(scores.dtype()).device(scores.device()); 23 | auto threshold = torch::zeros({batch_size, 1}, options); 24 | 25 | // Sketch parameters 26 | const int num_sketches = 4; 27 | const int sketch_width = std::min(1024L, feature_size / 4); 28 | 29 | // Standard C++ hash function 30 | std::hash hasher; 31 | 32 | // Process each batch item in parallel using at::parallel_for 33 | AT_DISPATCH_FLOATING_TYPES(scores.scalar_type(), "approx_topk_count_min_sketch", [&] 34 | { 35 | auto scores_accessor = scores.accessor(); 36 | auto threshold_accessor = threshold.accessor(); 37 | 38 | // Parallel processing over batch dimension 39 | // Use grain_size of 1 for fine-grained parallelism 40 | at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) { 41 | for (int64_t batch_idx = start; batch_idx < end; ++batch_idx) { 42 | // Initialize sketches with negative infinity (thread-local) 43 | std::vector> sketches(num_sketches, 44 | std::vector(sketch_width, -std::numeric_limits::infinity())); 45 | 46 | // Update sketches with maximum values at hash positions 47 | for (int sketch_idx = 0; sketch_idx < num_sketches; ++sketch_idx) { 48 | for (int64_t feature_idx = 0; feature_idx < feature_size; ++feature_idx) { 49 | // Use different hash functions for each sketch by combining with sketch_idx 50 | int64_t combined_key = sketch_idx * feature_size + feature_idx; 51 | int64_t hash_pos = hasher(combined_key) % sketch_width; 52 | 53 | scalar_t value = scores_accessor[batch_idx][feature_idx]; 54 | sketches[sketch_idx][hash_pos] = std::max(sketches[sketch_idx][hash_pos], value); 55 | } 56 | } 57 | 58 | // Collect all sketch values (thread-local) 59 | std::vector all_sketch_values; 60 | for (const auto& sketch : sketches) { 61 | for (scalar_t val : sketch) { 62 | if (val != -std::numeric_limits::infinity()) { 63 | all_sketch_values.push_back(val); 64 | } 65 | } 66 | } 67 | 68 | if (!all_sketch_values.empty()) { 69 | // Find approximate threshold 70 | int64_t sketch_k = std::max(1L, static_cast(k * all_sketch_values.size() / feature_size)); 71 | sketch_k = std::min(sketch_k, static_cast(all_sketch_values.size())); 72 | 73 | std::nth_element(all_sketch_values.begin(), 74 | all_sketch_values.begin() + sketch_k - 1, 75 | all_sketch_values.end(), 76 | std::greater()); 77 | 78 | // Apply adjustment factor for approximation error 79 | scalar_t adjustment_factor = 0.9; 80 | threshold_accessor[batch_idx][0] = all_sketch_values[sketch_k - 1] * adjustment_factor; 81 | } else { 82 | threshold_accessor[batch_idx][0] = 0.0; 83 | } 84 | } 85 | }); }); 86 | 87 | return threshold; 88 | } -------------------------------------------------------------------------------- 
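The header above computes an approximate per-row top-k threshold in O(n): each row's scores are hashed into a few fixed-width "max" sketches, an order statistic of the collected sketch values stands in for the true k-th largest score, and a 0.9 adjustment factor compensates for approximation error. A minimal call through the operator registered in sparse_mlp_op.cpp (shapes and k are illustrative, and the mask derivation is just one plausible use of the threshold):

    import torch
    from sparse_transformers import approx_topk_threshold

    scores = torch.randn(4, 8192)              # [batch, features], float32 on CPU
    thr = approx_topk_threshold(scores, 512)   # [batch, 1] approximate k-th largest value
    mask = scores >= thr                       # per-row sparsity mask (broadcast compare)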
/sparse_transformers/csrc/sparse_mlp_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // Forward declarations with timing buffer 9 | template 10 | __global__ void sparse_mlp_combined_cuda_kernel( 11 | const scalar_t* __restrict__ input, 12 | const scalar_t* __restrict__ concat_weight, 13 | scalar_t* __restrict__ combined_buffer, 14 | const int batch_size, 15 | const int hidden_size, 16 | const int intermediate_size); 17 | 18 | template 19 | __global__ void sparse_mlp_output_cuda_kernel( 20 | const scalar_t* __restrict__ combined_buffer, 21 | const scalar_t* __restrict__ active_down_weight, 22 | scalar_t* __restrict__ output, 23 | const int batch_size, 24 | const int hidden_size, 25 | const int intermediate_size); 26 | 27 | // First kernel for float 28 | template <> 29 | __global__ void sparse_mlp_combined_cuda_kernel( 30 | const float* __restrict__ input, 31 | const float* __restrict__ concat_weight, 32 | float* __restrict__ combined_buffer, 33 | const int batch_size, 34 | const int hidden_size, 35 | const int intermediate_size) { 36 | 37 | const int batch_idx = threadIdx.z + blockIdx.z * blockDim.z; 38 | const int intermediate_idx = threadIdx.y + blockIdx.y * blockDim.y; 39 | const int hidden_idx = threadIdx.x + blockIdx.x * blockDim.x; 40 | 41 | if (batch_idx >= batch_size || intermediate_idx >= intermediate_size || hidden_idx >= hidden_size) 42 | return; 43 | 44 | // Get batch pointers 45 | const float* batch_input = input + batch_idx * hidden_size; 46 | float* batch_combined = combined_buffer + batch_idx * intermediate_size * 2; 47 | 48 | // Compute sum for this thread 49 | float sum = batch_input[hidden_idx] * concat_weight[intermediate_idx * hidden_size + hidden_idx]; 50 | 51 | // Warp reduction 52 | for (int offset = blockDim.x / 2; offset > 0; offset /= 2) { 53 | sum += __shfl_down_sync(0xffffffff, sum, offset); 54 | } 55 | 56 | // Atomic add to global combined buffer 57 | if (threadIdx.x == 0) { 58 | atomicAdd(&batch_combined[intermediate_idx], sum); 59 | } 60 | } 61 | 62 | // Second kernel: compute output using combined values 63 | template <> 64 | __global__ void sparse_mlp_output_cuda_kernel( 65 | const float* __restrict__ combined_buffer, 66 | const float* __restrict__ active_down_weight, 67 | float* __restrict__ output, 68 | const int batch_size, 69 | const int hidden_size, 70 | const int intermediate_size) { 71 | 72 | const int batch_idx = threadIdx.z + blockIdx.z * blockDim.z; 73 | const int intermediate_idx = threadIdx.y + blockIdx.y * blockDim.y; 74 | const int hidden_idx = threadIdx.x + blockIdx.x * blockDim.x; 75 | 76 | if (batch_idx >= batch_size || intermediate_idx >= intermediate_size || hidden_idx >= hidden_size) 77 | return; 78 | 79 | // Get batch pointers 80 | const float* batch_combined = combined_buffer + batch_idx * intermediate_size * 2; 81 | 82 | const float gate_val = batch_combined[intermediate_idx]; 83 | const float gate = 1.0f / (1.0f + expf(-gate_val)); 84 | const float up = batch_combined[intermediate_idx + intermediate_size]; 85 | const float down = active_down_weight[hidden_idx * intermediate_size + intermediate_idx]; 86 | const float val = gate * up * down; 87 | atomicAdd(&output[batch_idx * hidden_size + hidden_idx], val); 88 | } 89 | 90 | // First kernel for double 91 | template <> 92 | __global__ void sparse_mlp_combined_cuda_kernel( 93 | const double* __restrict__ input, 94 | const double* __restrict__ concat_weight, 95 | 
double* __restrict__ combined_buffer, 96 | const int batch_size, 97 | const int hidden_size, 98 | const int intermediate_size) { 99 | 100 | const int batch_idx = threadIdx.z + blockIdx.z * blockDim.z; 101 | const int intermediate_idx = threadIdx.y + blockIdx.y * blockDim.y; 102 | const int hidden_idx = threadIdx.x + blockIdx.x * blockDim.x; 103 | 104 | if (batch_idx >= batch_size || intermediate_idx >= intermediate_size || hidden_idx >= hidden_size) 105 | return; 106 | 107 | // Get batch pointers 108 | const double* batch_input = input + batch_idx * hidden_size; 109 | double* batch_combined = combined_buffer + batch_idx * intermediate_size * 2; 110 | 111 | // Compute sum for this thread 112 | double sum = batch_input[hidden_idx] * concat_weight[intermediate_idx * hidden_size + hidden_idx]; 113 | 114 | // Warp reduction 115 | for (int offset = blockDim.x / 2; offset > 0; offset /= 2) { 116 | sum += __shfl_down_sync(0xffffffff, sum, offset); 117 | } 118 | 119 | // Atomic add to global combined buffer 120 | if (threadIdx.x == 0) { 121 | atomicAdd(&batch_combined[batch_idx * intermediate_size*2 + intermediate_idx], sum); 122 | } 123 | } 124 | 125 | // Second kernel for double 126 | template <> 127 | __global__ void sparse_mlp_output_cuda_kernel( 128 | const double* __restrict__ combined_buffer, 129 | const double* __restrict__ active_down_weight, 130 | double* __restrict__ output, 131 | const int batch_size, 132 | const int hidden_size, 133 | const int intermediate_size) { 134 | 135 | const int batch_idx = threadIdx.z + blockIdx.z * blockDim.z; 136 | const int intermediate_idx = threadIdx.y + blockIdx.y * blockDim.y; 137 | const int hidden_idx = threadIdx.x + blockIdx.x * blockDim.x; 138 | 139 | if (batch_idx >= batch_size || intermediate_idx >= intermediate_size || hidden_idx >= hidden_size) 140 | return; 141 | 142 | // Get batch pointers 143 | const double* batch_combined = combined_buffer + batch_idx * intermediate_size * 2; 144 | 145 | const double gate_val = batch_combined[intermediate_idx]; 146 | const double gate = 1.0 / (1.0 + exp(-gate_val)); 147 | const double up = batch_combined[intermediate_idx + intermediate_size]; 148 | const double down = active_down_weight[hidden_idx * intermediate_size + intermediate_idx]; 149 | const double val = gate * up * down; 150 | atomicAdd(&output[batch_idx * hidden_size + hidden_idx], val); 151 | } 152 | 153 | // First kernel for half precision 154 | template <> 155 | __global__ void sparse_mlp_combined_cuda_kernel( 156 | const at::Half* __restrict__ input, 157 | const at::Half* __restrict__ concat_weight, 158 | at::Half* __restrict__ combined_buffer, 159 | const int batch_size, 160 | const int hidden_size, 161 | const int intermediate_size) { 162 | 163 | const int tid = threadIdx.x; 164 | const int hidden_idx = blockIdx.x * blockDim.x + tid; 165 | const int intermediate_idx = blockIdx.y * 16; 166 | const int batch_idx = blockIdx.z; 167 | const int lane_id = tid % 16; 168 | 169 | if (hidden_idx >= hidden_size || 2*intermediate_idx >= intermediate_size) return; 170 | 171 | __shared__ __half2 warp_sums[32]; 172 | // Get batch pointers with proper alignment 173 | const __half2* batch_input = reinterpret_cast(input) + batch_idx * hidden_size/2; 174 | __half2* batch_combined = reinterpret_cast<__half2*>(combined_buffer) + batch_idx * intermediate_size; 175 | const __half2* weight_ptr = reinterpret_cast(concat_weight); 176 | __half2 input_pair = batch_input[hidden_idx]; 177 | 178 | // Process warp-sized chunk of intermediate dimension 179 | #pragma unroll 8 
180 | for (int i = 0; i < 16 && intermediate_idx + i*2 < intermediate_size; i+=2) { 181 | // Multiply both pairs at once 182 | __half2 sum = __hmul2(input_pair, weight_ptr[(intermediate_idx + i) * hidden_size/2 + hidden_idx]); 183 | __half2 sum2 = __hmul2(input_pair, weight_ptr[(intermediate_idx + i + 1) * hidden_size/2 + hidden_idx]); 184 | __half2 sum3 = __hmul2(input_pair, weight_ptr[(intermediate_idx + i + intermediate_size/2) * hidden_size/2 + hidden_idx]); 185 | __half2 sum4 = __hmul2(input_pair, weight_ptr[(intermediate_idx + i + intermediate_size/2 + 1) * hidden_size/2 + hidden_idx]); 186 | 187 | // Optimized warp reduction using butterfly pattern with half2 188 | #pragma unroll 189 | for (int mask = blockDim.x / 2; mask > 0; mask >>= 1) { 190 | sum = __hadd2(sum, __shfl_xor_sync(0xffffffff, sum, mask)); 191 | sum2 = __hadd2(sum2, __shfl_xor_sync(0xffffffff, sum2, mask)); 192 | sum3 = __hadd2(sum3, __shfl_xor_sync(0xffffffff, sum3, mask)); 193 | sum4 = __hadd2(sum4, __shfl_xor_sync(0xffffffff, sum4, mask)); 194 | } 195 | 196 | // Store results to shared memory 197 | if (tid == 0) { 198 | warp_sums[i] = sum; 199 | warp_sums[i+1] = sum2; 200 | warp_sums[i+16] = sum3; 201 | warp_sums[i+17] = sum4; 202 | } 203 | } 204 | 205 | __syncwarp(); 206 | 207 | // Have first warp do the atomic adds 208 | if (tid < 16 && intermediate_idx + lane_id < intermediate_size) { 209 | atomicAdd(&batch_combined[intermediate_idx + lane_id], warp_sums[lane_id]); 210 | atomicAdd(&batch_combined[intermediate_idx + intermediate_size/2 + lane_id], warp_sums[lane_id+32]); 211 | } 212 | } 213 | 214 | // Second kernel for half precision 215 | template <> 216 | __global__ void sparse_mlp_output_cuda_kernel( 217 | const at::Half* __restrict__ combined_buffer, 218 | const at::Half* __restrict__ active_down_weight, 219 | at::Half* __restrict__ output, 220 | const int batch_size, 221 | const int hidden_size, 222 | const int intermediate_size) { 223 | 224 | const int tid = threadIdx.x; 225 | const int intermediate_idx = blockIdx.x * blockDim.x + tid; 226 | const int hidden_idx = blockIdx.y * 16; 227 | const int batch_idx = blockIdx.z; 228 | const int lane_id = tid % 16; 229 | 230 | if (2*intermediate_idx >= intermediate_size) return; 231 | 232 | // Shared memory for partial sums and intermediate values 233 | __shared__ __half2 shared_sums[16]; 234 | __half2 gate_up_cache; // Cache gate/up values for reuse 235 | 236 | // Get batch pointers with proper alignment 237 | const __half2* batch_combined2 = reinterpret_cast(combined_buffer) + 238 | batch_idx * intermediate_size; 239 | const __half2* down_ptr2 = reinterpret_cast(active_down_weight); 240 | __half2* out_ptr2 = reinterpret_cast<__half2*>(output) + (batch_idx * (hidden_size) + hidden_idx)/2; 241 | 242 | // Load and process gate/up values - cache in shared memory 243 | __half2 combined = batch_combined2[intermediate_idx]; 244 | float2 gate_val = __half22float2(combined); 245 | __half2 gate = __float2half2_rn(1.0f / (1.0f + expf(-gate_val.x))); 246 | __half2 up = batch_combined2[intermediate_idx+intermediate_size/2]; 247 | gate_up_cache = __hmul2(gate, up); 248 | 249 | // Process 4 elements per iteration using 2x half2 250 | #pragma unroll 8 251 | for (int i = 0; i < 16 && hidden_idx + i*2 < hidden_size; i += 2) { 252 | // Load two pairs of down weights 253 | __half2 down1 = down_ptr2[(hidden_idx + i) * (intermediate_size/2) + intermediate_idx]; 254 | __half2 down2 = down_ptr2[(hidden_idx + i + 1) * (intermediate_size/2) + intermediate_idx]; 255 | 256 | // Multiply with 
cached gate_up values 257 | __half2 sum1 = __hmul2(gate_up_cache, down1); 258 | __half2 sum2 = __hmul2(gate_up_cache, down2); 259 | 260 | // Warp reduction for both pairs 261 | #pragma unroll 262 | for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { 263 | sum1 = __hadd2(sum1, __shfl_xor_sync(0xffffffff, sum1, offset)); 264 | sum2 = __hadd2(sum2, __shfl_xor_sync(0xffffffff, sum2, offset)); 265 | } 266 | 267 | // Store results 268 | if (tid == 0) { 269 | shared_sums[i] = sum1; 270 | shared_sums[i+1] = sum2; 271 | } 272 | } 273 | 274 | __syncwarp(); 275 | 276 | // Coalesced writes to global memory - 4 elements at a time 277 | if (tid < 8 && hidden_idx + lane_id*2 < hidden_size) { 278 | atomicAdd(&out_ptr2[lane_id*2], shared_sums[lane_id*2]); 279 | atomicAdd(&out_ptr2[lane_id*2+1], shared_sums[lane_id*2+1]); 280 | } 281 | } 282 | 283 | // Main CUDA implementation 284 | torch::Tensor sparse_mlp_forward_cuda( 285 | const torch::Tensor& input, 286 | const torch::Tensor& concat_weight, 287 | const torch::Tensor& active_down_weight, 288 | torch::Tensor& down_proj_buffer, 289 | torch::Tensor& combined_proj_buffer, 290 | const std::string& activation_fn) { 291 | const auto batch_size = input.size(0); 292 | const auto hidden_size = input.size(1); 293 | const auto intermediate_size = concat_weight.size(0) / 2; 294 | 295 | const int threads_per_block = 256; 296 | const int blocks_x = (hidden_size + threads_per_block - 1) / (2*threads_per_block); 297 | 298 | dim3 grid(blocks_x, 299 | (intermediate_size + 15) / 16, // Group by warps 300 | batch_size); 301 | dim3 block(threads_per_block, 1, 1); 302 | 303 | // Get current CUDA stream 304 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.device().index()); 305 | 306 | // Launch first kernel with timing buffer 307 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "sparse_mlp_combined_cuda", [&] { 308 | sparse_mlp_combined_cuda_kernel<<>>( 309 | input.data_ptr(), 310 | concat_weight.data_ptr(), 311 | combined_proj_buffer.data_ptr(), 312 | batch_size, 313 | hidden_size, 314 | intermediate_size 315 | ); 316 | }); 317 | 318 | const int blocks_intermediate = (intermediate_size + threads_per_block - 1) / (2*threads_per_block); 319 | 320 | dim3 grid2(blocks_intermediate, 321 | (hidden_size + 15) / 16, // Group by warps 322 | batch_size); 323 | dim3 block2(threads_per_block, 1, 1); 324 | cudaStreamSynchronize(stream); 325 | 326 | // Launch second kernel with timing buffer 327 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "sparse_mlp_output_cuda", [&] { 328 | sparse_mlp_output_cuda_kernel<<>>( 329 | combined_proj_buffer.data_ptr(), 330 | active_down_weight.data_ptr(), 331 | down_proj_buffer.data_ptr(), 332 | batch_size, 333 | hidden_size, 334 | intermediate_size 335 | ); 336 | }); 337 | 338 | return down_proj_buffer; 339 | } -------------------------------------------------------------------------------- /sparse_transformers/csrc/sparse_mlp_op.cpp: -------------------------------------------------------------------------------- 1 | // For TorchScript support 2 | #include 3 | 4 | // For PyTorch C++ extension support 5 | #include 6 | 7 | // For tensor operations 8 | #include 9 | 10 | // For PyTorch's OpenMP wrapper 11 | #include 12 | 13 | // Add pybind11 and namespace 14 | #include 15 | namespace py = pybind11; 16 | 17 | // Add required headers 18 | #include 19 | #include 20 | #include 21 | 22 | // Add device check utilities 23 | #include 24 | 25 | // Add custom headers 26 | #include "weight_cache.h" 27 | #include 
"approx_topk.h" 28 | 29 | // Forward declarations of CPU/CUDA implementations 30 | torch::Tensor sparse_mlp_forward_cpu( 31 | const torch::Tensor &input, 32 | const torch::Tensor &concat_weight, 33 | const torch::Tensor &active_down_weight, 34 | torch::Tensor &down_proj_buffer, 35 | torch::Tensor &combined_proj_buffer, 36 | const std::string &activation_fn); 37 | 38 | #ifdef WITH_CUDA 39 | torch::Tensor sparse_mlp_forward_cuda( 40 | const torch::Tensor &input, 41 | const torch::Tensor &concat_weight, 42 | const torch::Tensor &active_down_weight, 43 | torch::Tensor &down_proj_buffer, 44 | torch::Tensor &combined_proj_buffer, 45 | const std::string &activation_fn); 46 | #endif 47 | 48 | // Main dispatch function 49 | torch::Tensor sparse_mlp_forward( 50 | const torch::Tensor &input, 51 | const torch::Tensor &concat_weight, 52 | const torch::Tensor &active_down_weight, 53 | torch::Tensor &down_proj_buffer, 54 | torch::Tensor &combined_proj_buffer, 55 | const std::string &activation_fn) 56 | { 57 | 58 | // Check if input is on CUDA and dispatch accordingly 59 | if (input.is_cuda()) 60 | { 61 | #ifdef WITH_CUDA 62 | return sparse_mlp_forward_cuda(input, concat_weight, active_down_weight, down_proj_buffer, combined_proj_buffer, activation_fn); 63 | #else 64 | AT_ERROR("CUDA not available - cannot run on GPU"); 65 | #endif 66 | } 67 | else 68 | { 69 | return sparse_mlp_forward_cpu(input, concat_weight, active_down_weight, down_proj_buffer, combined_proj_buffer, activation_fn); 70 | } 71 | } 72 | 73 | // CPU implementation 74 | torch::Tensor sparse_mlp_forward_cpu( 75 | const torch::Tensor &input, 76 | const torch::Tensor &concat_weight, 77 | const torch::Tensor &active_down_weight, 78 | torch::Tensor &down_proj_buffer, 79 | torch::Tensor &combined_proj_buffer, 80 | const std::string &activation_fn) 81 | { 82 | 83 | const auto batch_size = input.size(0); 84 | const auto hidden_size = input.size(1); 85 | 86 | // Ensure output buffer is correctly sized 87 | if (down_proj_buffer.size(0) != batch_size) 88 | { 89 | down_proj_buffer.resize_({batch_size, hidden_size}); 90 | } 91 | if (combined_proj_buffer.size(0) != batch_size) 92 | { 93 | combined_proj_buffer.resize_({batch_size, 2 * int(concat_weight.size(0))}); 94 | } 95 | 96 | // Process each batch item in parallel 97 | at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) 98 | { 99 | for (int64_t batch_idx = start; batch_idx < end; batch_idx++) { 100 | int64_t gate_size = concat_weight.size(0) / 2; 101 | auto x_batch = input[batch_idx].unsqueeze(0).detach(); 102 | 103 | // Single matmul for both gate and up projections 104 | auto proj_view = combined_proj_buffer[batch_idx].unsqueeze(0).narrow(1, 0, concat_weight.size(0)); 105 | torch::matmul_out(proj_view, x_batch, concat_weight.t()); 106 | 107 | // Split result into gate and up projections 108 | auto gate_proj = proj_view.narrow(1, 0, gate_size); 109 | auto up_proj = proj_view.narrow(1, gate_size, gate_size); 110 | 111 | // Apply activations 112 | gate_proj.mul_(torch::sigmoid(gate_proj)); 113 | gate_proj.mul_(up_proj); 114 | 115 | // Final projection 116 | down_proj_buffer[batch_idx] = torch::matmul(gate_proj, active_down_weight.t())[0]; 117 | } }); 118 | return down_proj_buffer; 119 | } 120 | 121 | // Register TorchScript custom classes and operators 122 | TORCH_LIBRARY(sparse_mlp, m) 123 | { 124 | // Register the optimized weight cache 125 | m.class_("WeightCache") 126 | .def(torch::init()) 127 | .def("update_active_weights", &WeightCache::update_active_weights) 128 | 
.def("get_concat_weight", &WeightCache::get_concat_weight) 129 | .def("get_active_down_weight", &WeightCache::get_active_down_weight); 130 | 131 | // Register sparse MLP operator 132 | m.def("forward", sparse_mlp_forward); 133 | 134 | // Register Count-Min Sketch approximate top-k threshold operator 135 | m.def("approx_topk_threshold", approx_topk_threshold); 136 | } -------------------------------------------------------------------------------- /sparse_transformers/csrc/weight_cache.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include // For SIMD operations 12 | 13 | class WeightCache : public torch::CustomClassHolder 14 | { 15 | private: 16 | // Define deleter as a struct to avoid std::function overhead 17 | struct AlignedDeleter 18 | { 19 | void operator()(float *ptr) const 20 | { 21 | free(ptr); 22 | } 23 | }; 24 | 25 | bool is_initialized = false; 26 | 27 | // Memory pools for all weight data (cache-aligned) 28 | std::unique_ptr gate_memory_pool; 29 | std::unique_ptr up_memory_pool; 30 | std::unique_ptr down_memory_pool_transposed; // Store transposed for fast row access 31 | 32 | // Matrix dimensions 33 | int64_t hidden_dim = 0; 34 | int64_t sparse_dim = 0; 35 | int64_t gate_row_size = 0; 36 | int64_t up_row_size = 0; 37 | int64_t down_row_size = 0; // This becomes hidden_dim after transpose 38 | 39 | torch::ScalarType dtype; 40 | torch::Device current_device = torch::kCPU; 41 | 42 | // Currently active indices (maintained in order for contiguous access) 43 | std::vector active_indices; 44 | 45 | // Mapping from active_index to position in active_indices for O(1) lookup 46 | std::unordered_map index_to_position; 47 | 48 | // Contiguous buffers for active data (always packed) 49 | std::unique_ptr active_gate_buffer; 50 | std::unique_ptr active_up_buffer; 51 | std::unique_ptr active_down_buffer; 52 | 53 | // Current mask for differential updates 54 | torch::Tensor current_mask; 55 | 56 | // Cached active weight tensors - use from_blob to reference our buffers directly 57 | torch::Tensor active_weights_cache; 58 | torch::Tensor active_downs_cache; 59 | bool cache_valid = false; 60 | 61 | // Max expected active indices (dynamic based on intermediate_size) 62 | size_t max_active_indices = 0; 63 | 64 | // Cache-aligned memory allocation 65 | static void *aligned_alloc_wrapper(size_t size) 66 | { 67 | void *ptr = nullptr; 68 | if (posix_memalign(&ptr, 64, size) != 0) 69 | { // 64-byte alignment for cache lines 70 | throw std::bad_alloc(); 71 | } 72 | return ptr; 73 | } 74 | 75 | // Find differential changes between masks using PyTorch operations 76 | struct MaskDiff 77 | { 78 | std::vector added_indices; 79 | std::vector removed_indices; 80 | }; 81 | 82 | MaskDiff compute_mask_diff(const torch::Tensor &old_mask, const torch::Tensor &new_mask) 83 | { 84 | MaskDiff diff; 85 | 86 | // Use PyTorch operations for efficient mask comparison 87 | auto added_mask = new_mask & (~old_mask); // new & ~old = added 88 | auto removed_mask = old_mask & (~new_mask); // old & ~new = removed 89 | 90 | // Get indices of added and removed elements 91 | auto added_indices_tensor = torch::nonzero(added_mask).squeeze(-1); 92 | auto removed_indices_tensor = torch::nonzero(removed_mask).squeeze(-1); 93 | 94 | // Convert to std::vector 95 | if (added_indices_tensor.numel() > 0) 96 | { 97 | auto added_data = added_indices_tensor.data_ptr(); 98 | 
diff.added_indices.assign(added_data, added_data + added_indices_tensor.numel()); 99 | } 100 | 101 | if (removed_indices_tensor.numel() > 0) 102 | { 103 | auto removed_data = removed_indices_tensor.data_ptr(); 104 | diff.removed_indices.assign(removed_data, removed_data + removed_indices_tensor.numel()); 105 | } 106 | 107 | return diff; 108 | } 109 | 110 | // Rebuild tensors using from_blob to reference our contiguous buffers 111 | void rebuild_tensor_views() 112 | { 113 | const size_t num_active = active_indices.size(); 114 | 115 | if (num_active == 0) 116 | { 117 | auto options = torch::TensorOptions().device(current_device).dtype(dtype); 118 | active_weights_cache = torch::empty({0, hidden_dim}, options); 119 | active_downs_cache = torch::empty({0, hidden_dim}, options); 120 | return; 121 | } 122 | 123 | // Create gate tensor directly from buffer 124 | auto gate_tensor = torch::from_blob(active_gate_buffer.get(), 125 | {static_cast(num_active), gate_row_size}, 126 | torch::TensorOptions().dtype(dtype)); 127 | 128 | // Create up tensor directly from buffer 129 | auto up_tensor = torch::from_blob(active_up_buffer.get(), 130 | {static_cast(num_active), up_row_size}, 131 | torch::TensorOptions().dtype(dtype)); 132 | 133 | // Create down tensor directly from buffer and transpose 134 | auto down_tensor_packed = torch::from_blob(active_down_buffer.get(), 135 | {static_cast(num_active), hidden_dim}, 136 | torch::TensorOptions().dtype(dtype)); 137 | auto down_tensor = down_tensor_packed.t(); // [hidden_dim, num_active] 138 | 139 | // Concatenate and move to target device 140 | active_weights_cache = torch::cat({gate_tensor, up_tensor}, 0).to(current_device); 141 | active_downs_cache = down_tensor.to(current_device); 142 | } 143 | 144 | public: 145 | WeightCache(const torch::Tensor &init_mask, int64_t hidden_size, 146 | const torch::Tensor &gate_weight, const torch::Tensor &up_weight, 147 | const torch::Tensor &down_weight) 148 | { 149 | init(init_mask, hidden_size, gate_weight, up_weight, down_weight); 150 | } 151 | 152 | void init(const torch::Tensor &init_mask, int64_t hidden_size, 153 | const torch::Tensor &gate_weight, const torch::Tensor &up_weight, 154 | const torch::Tensor &down_weight) 155 | { 156 | 157 | current_device = gate_weight.device(); 158 | dtype = gate_weight.scalar_type(); 159 | 160 | // Store dimensions 161 | hidden_dim = hidden_size; 162 | sparse_dim = gate_weight.size(0); 163 | max_active_indices = init_mask.sum().item(); 164 | gate_row_size = gate_weight.size(1); 165 | up_row_size = up_weight.size(1); 166 | down_row_size = hidden_dim; // After transpose: [intermediate_size, hidden_size] 167 | 168 | // Allocate cache-aligned memory pools 169 | const size_t gate_total_size = sparse_dim * gate_row_size; 170 | const size_t up_total_size = sparse_dim * up_row_size; 171 | const size_t down_total_size = sparse_dim * hidden_dim; // Transposed shape 172 | 173 | gate_memory_pool = std::unique_ptr( 174 | static_cast(aligned_alloc_wrapper(gate_total_size * sizeof(float)))); 175 | up_memory_pool = std::unique_ptr( 176 | static_cast(aligned_alloc_wrapper(up_total_size * sizeof(float)))); 177 | down_memory_pool_transposed = std::unique_ptr( 178 | static_cast(aligned_alloc_wrapper(down_total_size * sizeof(float)))); 179 | 180 | // Pre-allocate contiguous buffers for active weights 181 | active_gate_buffer = std::unique_ptr( 182 | static_cast(aligned_alloc_wrapper(max_active_indices * gate_row_size * sizeof(float)))); 183 | active_up_buffer = std::unique_ptr( 184 | 
static_cast(aligned_alloc_wrapper(max_active_indices * up_row_size * sizeof(float)))); 185 | active_down_buffer = std::unique_ptr( 186 | static_cast(aligned_alloc_wrapper(max_active_indices * hidden_dim * sizeof(float)))); 187 | 188 | // Initialize differential update tracking 189 | index_to_position.reserve(max_active_indices); 190 | 191 | // Copy weights to memory pools 192 | auto gate_cpu = gate_weight.to(torch::kCPU).contiguous(); 193 | auto up_cpu = up_weight.to(torch::kCPU).contiguous(); 194 | auto down_cpu = down_weight.to(torch::kCPU).contiguous(); 195 | 196 | // Copy gate and up weights directly (row-major format) 197 | std::memcpy(gate_memory_pool.get(), gate_cpu.data_ptr(), gate_total_size * sizeof(float)); 198 | std::memcpy(up_memory_pool.get(), up_cpu.data_ptr(), up_total_size * sizeof(float)); 199 | 200 | // Transpose down matrix during copy: [hidden_size, intermediate_size] -> [intermediate_size, hidden_size] 201 | auto down_data = down_cpu.data_ptr(); 202 | for (int64_t i = 0; i < sparse_dim; ++i) 203 | { 204 | for (int64_t j = 0; j < hidden_dim; ++j) 205 | { 206 | down_memory_pool_transposed[i * hidden_dim + j] = down_data[j * sparse_dim + i]; 207 | } 208 | } 209 | 210 | is_initialized = true; 211 | current_mask = torch::zeros(sparse_dim, torch::TensorOptions().dtype(torch::kBool).device(current_device)); 212 | 213 | // Initialize with mask 214 | update_active_weights(init_mask); 215 | } 216 | 217 | void update_active_weights(const torch::Tensor &mask) 218 | { 219 | if (!is_initialized) 220 | return; 221 | 222 | // Compute diff with normalization handled internally 223 | auto diff = compute_mask_diff(current_mask, mask); 224 | 225 | // Early exit if no changes - avoid all processing work! 226 | if (diff.added_indices.empty() && diff.removed_indices.empty()) 227 | { 228 | return; 229 | } 230 | 231 | // Optimized single-pass removal+addition logic 232 | const size_t num_removals = diff.removed_indices.size(); 233 | const size_t num_additions = diff.added_indices.size(); 234 | const size_t pairs_to_process = std::min(num_removals, num_additions); 235 | 236 | // First pass: Pair removals with additions for direct replacement (most cache-efficient) 237 | for (size_t i = 0; i < pairs_to_process; ++i) 238 | { 239 | int64_t removed_idx = diff.removed_indices[i]; 240 | int64_t added_idx = diff.added_indices[i]; 241 | 242 | auto it = index_to_position.find(removed_idx); 243 | if (it != index_to_position.end()) 244 | { 245 | size_t pos = it->second; 246 | 247 | // Direct replacement - copy new data over old position (single memcpy per matrix!) 
248 | std::memcpy(active_gate_buffer.get() + pos * gate_row_size, 249 | gate_memory_pool.get() + added_idx * gate_row_size, 250 | gate_row_size * sizeof(float)); 251 | 252 | std::memcpy(active_up_buffer.get() + pos * up_row_size, 253 | up_memory_pool.get() + added_idx * up_row_size, 254 | up_row_size * sizeof(float)); 255 | 256 | std::memcpy(active_down_buffer.get() + pos * hidden_dim, 257 | down_memory_pool_transposed.get() + added_idx * hidden_dim, 258 | hidden_dim * sizeof(float)); 259 | 260 | // Update tracking - remove old, add new at same position 261 | index_to_position.erase(it); 262 | active_indices[pos] = added_idx; 263 | index_to_position[added_idx] = pos; 264 | } 265 | } 266 | 267 | // Handle remaining additions (if more additions than removals) 268 | for (size_t i = pairs_to_process; i < num_additions; ++i) 269 | { 270 | int64_t added_idx = diff.added_indices[i]; 271 | size_t new_pos = active_indices.size(); 272 | 273 | if (new_pos >= max_active_indices) 274 | { 275 | continue; // Skip if buffer full 276 | } 277 | 278 | // Append to end 279 | std::memcpy(active_gate_buffer.get() + new_pos * gate_row_size, 280 | gate_memory_pool.get() + added_idx * gate_row_size, 281 | gate_row_size * sizeof(float)); 282 | 283 | std::memcpy(active_up_buffer.get() + new_pos * up_row_size, 284 | up_memory_pool.get() + added_idx * up_row_size, 285 | up_row_size * sizeof(float)); 286 | 287 | std::memcpy(active_down_buffer.get() + new_pos * hidden_dim, 288 | down_memory_pool_transposed.get() + added_idx * hidden_dim, 289 | hidden_dim * sizeof(float)); 290 | 291 | // Update tracking 292 | active_indices.push_back(added_idx); 293 | index_to_position[added_idx] = new_pos; 294 | } 295 | 296 | // Handle remaining removals (if more removals than additions) 297 | for (size_t i = pairs_to_process; i < num_removals; ++i) 298 | { 299 | int64_t removed_idx = diff.removed_indices[i]; 300 | auto it = index_to_position.find(removed_idx); 301 | if (it != index_to_position.end()) 302 | { 303 | size_t pos_to_remove = it->second; 304 | size_t last_pos = active_indices.size() - 1; 305 | 306 | if (pos_to_remove != last_pos) 307 | { 308 | // Move last element to fill gap 309 | int64_t last_idx = active_indices[last_pos]; 310 | 311 | std::memcpy(active_gate_buffer.get() + pos_to_remove * gate_row_size, 312 | active_gate_buffer.get() + last_pos * gate_row_size, 313 | gate_row_size * sizeof(float)); 314 | 315 | std::memcpy(active_up_buffer.get() + pos_to_remove * up_row_size, 316 | active_up_buffer.get() + last_pos * up_row_size, 317 | up_row_size * sizeof(float)); 318 | 319 | std::memcpy(active_down_buffer.get() + pos_to_remove * hidden_dim, 320 | active_down_buffer.get() + last_pos * hidden_dim, 321 | hidden_dim * sizeof(float)); 322 | 323 | // Update tracking 324 | active_indices[pos_to_remove] = last_idx; 325 | index_to_position[last_idx] = pos_to_remove; 326 | } 327 | 328 | // Remove last element 329 | active_indices.pop_back(); 330 | index_to_position.erase(it); 331 | } 332 | } 333 | 334 | // Rebuild tensor views using from_blob (no copying!) 
335 | rebuild_tensor_views(); 336 | cache_valid = true; 337 | current_mask = mask.clone(); 338 | } 339 | 340 | // Getters remain the same 341 | torch::Tensor get_concat_weight() const 342 | { 343 | TORCH_CHECK(cache_valid, "Cache is not valid"); 344 | return active_weights_cache; 345 | } 346 | 347 | torch::Tensor get_active_down_weight() const 348 | { 349 | TORCH_CHECK(cache_valid, "Cache is not valid"); 350 | return active_downs_cache; 351 | } 352 | 353 | size_t get_num_active() const 354 | { 355 | return active_indices.size(); 356 | } 357 | 358 | // Destructor - no manual cleanup needed with smart pointers 359 | ~WeightCache() = default; 360 | }; -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import llama 2 | from . import utilities 3 | from . import trainer -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import llama 2 | # from . import dia -------------------------------------------------------------------------------- /src/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configuration_llama_skip 2 | from . import modelling_llama_skip 3 | 4 | __all__ = [configuration_llama_skip, modelling_llama_skip] -------------------------------------------------------------------------------- /src/models/llama/configuration_llama_skip.py: -------------------------------------------------------------------------------- 1 | from transformers import LlamaConfig 2 | from optimum.utils import NormalizedTextConfig, MistralDummyPastKeyValuesGenerator, DummyTextInputGenerator 3 | import os 4 | from typing import Union 5 | from optimum.exporters.onnx.config import TextDecoderWithPositionIdsOnnxConfig 6 | 7 | 8 | class LlamaSkipConnectionConfig(LlamaConfig): 9 | model_type = "llama-skip" 10 | 11 | def __init__(self, 12 | sparsity: float, 13 | predictor_loss_type: str = "bce", 14 | predictor_temperature: float = 1.0, 15 | predictor_loss_alpha: float = 1.0, 16 | predictor_loss_weight: float = 0.1, 17 | use_optimized_weight_cache: bool = True, 18 | **kwargs): 19 | self._sparsity = sparsity 20 | self.predictor_loss_type = predictor_loss_type 21 | self.predictor_temperature = predictor_temperature 22 | self.predictor_loss_alpha = predictor_loss_alpha 23 | self.predictor_loss_weight = predictor_loss_weight 24 | self.use_optimized_weight_cache = use_optimized_weight_cache 25 | super().__init__(**kwargs) 26 | 27 | @property 28 | def sparsity(self): 29 | return self._sparsity 30 | 31 | @sparsity.setter 32 | def sparsity(self, value): 33 | self._sparsity = value 34 | 35 | @classmethod 36 | def from_json_file(cls, json_file: Union[str, os.PathLike]): 37 | """ 38 | Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters. 39 | 40 | Args: 41 | json_file (`str` or `os.PathLike`): 42 | Path to the JSON file containing the parameters. 43 | 44 | Returns: 45 | [`PretrainedConfig`]: The configuration object instantiated from that JSON file. 46 | 47 | """ 48 | config_dict = cls._dict_from_json_file(json_file) 49 | return cls(**config_dict) 50 | 51 | 52 | class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): 53 | DEFAULT_ONNX_OPSET = 14 # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
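# Usage note for LlamaSkipConnectionConfig above (illustrative; the sparsity value
# and file path are hypothetical):
#   cfg = LlamaSkipConnectionConfig(sparsity=0.8, predictor_loss_type="bce")
#   cfg = LlamaSkipConnectionConfig.from_json_file("path/to/llama_skip_config.json")
# `sparsity` is the only required argument; the predictor_* fields tune the sparsity
# predictor training objective, and any remaining kwargs fall through to LlamaConfig.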
54 | 55 | DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) 56 | DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator 57 | NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import attr 4 | import numpy as np 5 | import torch 6 | from sklearn import metrics 7 | from tqdm import tqdm 8 | 9 | from src.utilities import registry 10 | from src.utilities.saver import Saver 11 | from src.utilities.logger import BaseLogger 12 | from src.utilities.random import Reproducible 13 | 14 | 15 | @attr.s 16 | class TrainConfig: 17 | eval_every_n = attr.ib(default=10000) 18 | report_every_n = attr.ib(default=10) 19 | save_every_n = attr.ib(default=2000) 20 | keep_every_n = attr.ib(default=10000) 21 | 22 | batch_size = attr.ib(default=32) 23 | eval_batch_size = attr.ib(default=128) 24 | num_epochs = attr.ib(default=-1) 25 | 26 | num_batches = attr.ib(default=-1) 27 | 28 | @num_batches.validator 29 | def check_only_one_declaration(instance, _, value): 30 | if instance.num_epochs > 0 & value > 0: 31 | raise ValueError( 32 | "only one out of num_epochs and num_batches must be declared!") 33 | 34 | num_eval_batches = attr.ib(default=-1) 35 | eval_on_train = attr.ib(default=False) 36 | eval_on_val = attr.ib(default=True) 37 | 38 | num_workers = attr.ib(default=0) 39 | pin_memory = attr.ib(default=True) 40 | log_gradients = attr.ib(default=False) 41 | 42 | 43 | class Trainer(Reproducible): 44 | def __init__( 45 | self, 46 | config_dict: Dict, 47 | logger: BaseLogger) -> None: 48 | 49 | super().__init__(config_dict) 50 | self.config_dict = config_dict 51 | 52 | self.train_config = TrainConfig(**config_dict["train"]["config"]) 53 | self.logger = logger 54 | modelCls = registry.lookup('model', config_dict["model"]) 55 | self.model_preproc = registry.instantiate( 56 | modelCls.Preproc, 57 | config_dict["model"]['preproc'], unused_keys=()) 58 | 59 | with self.model_random: 60 | # 1. Construct model 61 | self.model_preproc.load_data_description() 62 | self.model = registry.construct( 63 | 'model', self.config_dict["model"], 64 | preprocessor=self.model_preproc, 65 | unused_keys=('name', 'preproc') 66 | ) 67 | if torch.cuda.is_available(): 68 | self.model.cuda() 69 | 70 | self._data_loaders = {} 71 | self._scheduler = None 72 | 73 | with self.init_random: 74 | self.optimizer = registry.construct( 75 | 'optimizer', self.config_dict['trainer']['optimizer'], 76 | params=self.model.parameters()) 77 | self._saver = None 78 | 79 | def reset_loaders(self): 80 | self._data_loaders = {} 81 | 82 | @staticmethod 83 | def _yield_batches_from_epochs(loader, start_epoch): 84 | current_epoch = start_epoch 85 | while True: 86 | for batch in loader: 87 | yield batch, current_epoch 88 | current_epoch += 1 89 | 90 | @property 91 | def scheduler(self, optimizer, **kwargs): 92 | if self._scheduler is None: 93 | with self.init_random: 94 | self._scheduler = registry.construct( 95 | 'lr_scheduler', 96 | self.config_dict['trainer'].get( 97 | 'lr_scheduler', {'name': 'noop'}), 98 | optimizer=optimizer, **kwargs) 99 | return self._scheduler 100 | 101 | @property 102 | def saver(self): 103 | if self._saver is None: 104 | # 2. 
Restore model parameters 105 | self._saver = Saver( 106 | self.model, self.optimizer, 107 | keep_every_n=self.train_config.keep_every_n) 108 | return self._saver 109 | 110 | @property 111 | def data_loaders(self): 112 | if self._data_loaders: 113 | return self._data_loaders 114 | # TODO : FIX if not client_id will load whole dataset 115 | self.model_preproc.load() 116 | # 3. Get training data somewhere 117 | with self.data_random: 118 | train_data = self.model_preproc.dataset('train') 119 | train_data_loader = self.model_preproc.data_loader( 120 | train_data, 121 | batch_size=self.train_config.batch_size, 122 | num_workers=self.train_config.num_workers, 123 | pin_memory=self.train_config.pin_memory, 124 | persistent_workers=True, 125 | shuffle=True, 126 | drop_last=True) 127 | 128 | train_eval_data_loader = self.model_preproc.data_loader( 129 | train_data, 130 | pin_memory=self.train_config.pin_memory, 131 | num_workers=self.train_config.num_workers, 132 | persistent_workers=True, 133 | batch_size=self.train_config.eval_batch_size) 134 | 135 | val_data = self.model_preproc.dataset('val') 136 | val_data_loader = self.model_preproc.data_loader( 137 | val_data, 138 | num_workers=self.train_config.num_workers, 139 | pin_memory=self.train_config.pin_memory, 140 | persistent_workers=True, 141 | batch_size=self.train_config.eval_batch_size) 142 | self._data_loaders = { 143 | 'train': train_data_loader, 144 | 'train_eval': train_eval_data_loader, 145 | 'val': val_data_loader 146 | } 147 | return self._data_loaders 148 | @staticmethod 149 | def eval_model( 150 | model, 151 | loader, 152 | eval_section, 153 | logger, 154 | num_eval_batches=-1, 155 | best_acc_test=None, 156 | best_auc_test=None, 157 | step=-1): 158 | scores = [] 159 | targets = [] 160 | model.eval() 161 | total_len = num_eval_batches if num_eval_batches > 0 else len(loader) 162 | with torch.no_grad(): 163 | t_loader = tqdm(enumerate(loader), unit="batch", total=total_len) 164 | for i, testBatch in t_loader: 165 | # early exit if nbatches was set by the user and was exceeded 166 | if (num_eval_batches > 0) and (i >= num_eval_batches): 167 | break 168 | t_loader.set_description(f"Running {eval_section}") 169 | 170 | inputs, true_labels = testBatch 171 | 172 | # forward pass 173 | Z_test = model.get_scores(model(inputs)) 174 | 175 | S_test = Z_test.detach().cpu().numpy() # numpy array 176 | T_test = true_labels.detach().cpu().numpy() # numpy array 177 | 178 | scores.append(S_test) 179 | targets.append(T_test) 180 | 181 | model.train() 182 | scores = np.concatenate(scores, axis=0) 183 | targets = np.concatenate(targets, axis=0) 184 | metrics_dict = { 185 | "recall": lambda y_true, y_score: metrics.recall_score( 186 | y_true=y_true, y_pred=np.round(y_score) 187 | ), 188 | "precision": lambda y_true, y_score: metrics.precision_score( 189 | y_true=y_true, y_pred=np.round(y_score), zero_division=0.0 190 | ), 191 | "f1": lambda y_true, y_score: metrics.f1_score( 192 | y_true=y_true, y_pred=np.round(y_score) 193 | ), 194 | "ap": metrics.average_precision_score, 195 | "roc_auc": metrics.roc_auc_score, 196 | "accuracy": lambda y_true, y_score: metrics.accuracy_score( 197 | y_true=y_true, y_pred=np.round(y_score) 198 | ), 199 | } 200 | 201 | results = {} 202 | for metric_name, metric_function in metrics_dict.items(): 203 | results[metric_name] = metric_function(targets, scores) 204 | logger.add_scalar( 205 | eval_section + "/" + "mlperf-metrics/" + metric_name, 206 | results[metric_name], 207 | step, 208 | ) 209 | 210 | if (best_auc_test is not None) and\ 211 |
(results["roc_auc"] > best_auc_test): 212 | best_auc_test = results["roc_auc"] 213 | best_acc_test = results["accuracy"] 214 | return True, results 215 | 216 | return False, results 217 | 218 | def test(self): 219 | results = {} 220 | if self.train_config.eval_on_train: 221 | _, results['train_metrics'] = self.eval_model( 222 | self.model, 223 | self.data_loaders['train_eval'], 224 | eval_section='train_eval', 225 | num_eval_batches=self.train_config.num_eval_batches, 226 | logger=self.logger, step=-1) 227 | 228 | if self.train_config.eval_on_val: 229 | _, results['test_metrics'] = self.eval_model( 230 | self.model, 231 | self.data_loaders['test'], 232 | eval_section='test', 233 | logger=self.logger, 234 | num_eval_batches=self.train_config.num_eval_batches, 235 | step=-1) 236 | return results 237 | 238 | def train(self, modeldir=None): 239 | last_step, current_epoch = self.saver.restore(modeldir) 240 | lr_scheduler = self.scheduler( 241 | self.optimizer, last_epoch=last_step) 242 | 243 | if self.train_config.num_batches > 0: 244 | total_train_len = self.train_config.num_batches 245 | else: 246 | total_train_len = len(self.data_loaders['train']) 247 | train_dl = self._yield_batches_from_epochs( 248 | self.data_loaders['train'], start_epoch=current_epoch) 249 | 250 | # 4. Start training loop 251 | with self.data_random: 252 | best_acc_test = 0 253 | best_auc_test = 0 254 | dummy_input = next(iter(train_dl))[0] 255 | self.logger.add_graph(self.model, dummy_input[0]) 256 | t_loader = tqdm(train_dl, unit='batch', 257 | total=total_train_len) 258 | for batch, current_epoch in t_loader: 259 | t_loader.set_description(f"Training Epoch {current_epoch}") 260 | 261 | # Quit if too long 262 | if self.train_config.num_batches > 0 and\ 263 | last_step >= self.train_config.num_batches: 264 | break 265 | if self.train_config.num_epochs > 0 and\ 266 | current_epoch >= self.train_config.num_epochs: 267 | break 268 | 269 | # Evaluate model 270 | if last_step % self.train_config.eval_every_n == 0: 271 | if self.train_config.eval_on_train: 272 | self.eval_model( 273 | self.model, 274 | self.data_loaders['train_eval'], 275 | 'train_eval', 276 | self.logger, 277 | self.train_config.num_eval_batches, 278 | step=last_step) 279 | 280 | if self.train_config.eval_on_val: 281 | if self.eval_model( 282 | self.model, 283 | self.data_loaders['val'], 284 | 'val', 285 | self.logger, 286 | self.train_config.num_eval_batches, 287 | best_acc_test=best_acc_test, 288 | best_auc_test=best_auc_test, 289 | step=last_step)[1]: 290 | self.saver.save(modeldir, last_step, 291 | current_epoch, is_best=True) 292 | 293 | # Compute and apply gradient 294 | with self.model_random: 295 | input, true_label = batch 296 | output = self.model(input) 297 | loss = self.model.loss(output, true_label) 298 | self.optimizer.zero_grad() 299 | loss.backward() 300 | self.optimizer.step() 301 | lr_scheduler.step() 302 | 303 | # Report metrics 304 | if last_step % self.train_config.report_every_n == 0: 305 | t_loader.set_postfix({'loss': loss.item()}) 306 | self.logger.add_scalar( 307 | 'train/loss', loss.item(), global_step=last_step) 308 | self.logger.add_scalar( 309 | 'train/lr', lr_scheduler.last_lr[0], 310 | global_step=last_step) 311 | if self.train_config.log_gradients: 312 | self.logger.log_gradients(self.model, last_step) 313 | 314 | last_step += 1 315 | # Run saver 316 | if last_step % self.train_config.save_every_n == 0: 317 | self.saver.save(modeldir, last_step, current_epoch) 318 | return 319 | 
-------------------------------------------------------------------------------- /src/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cuda_utils 2 | from . import logger 3 | from . import random 4 | from . import registry 5 | from . import saver -------------------------------------------------------------------------------- /src/utilities/cuda_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import socket 3 | import logging 4 | 5 | 6 | def map_to_cuda(args, device=None, **kwargs): 7 | if isinstance(args, (list, tuple)): 8 | return [map_to_cuda(arg, device, **kwargs) for arg in args] 9 | elif isinstance(args, dict): 10 | return {k: map_to_cuda(v, device, **kwargs) for k, v in args.items()} 11 | elif isinstance(args, torch.Tensor): 12 | return args.cuda(device, **kwargs) 13 | else: 14 | raise TypeError("unsupported type for cuda migration") 15 | 16 | 17 | def map_to_list(model_params): 18 | for k in model_params.keys(): 19 | model_params[k] = model_params[k].detach().numpy().tolist() 20 | return model_params 21 | 22 | 23 | def mapping_processes_to_gpus(gpu_config, process_id, worker_number): 24 | if gpu_config == None: 25 | device = torch.device("cpu") 26 | logging.info(device) 27 | # return gpu_util_map[process_id][1] 28 | return device 29 | else: 30 | logging.info(gpu_config) 31 | gpu_util_map = {} 32 | i = 0 33 | for host, gpus_util_map_host in gpu_config.items(): 34 | for gpu_j, num_process_on_gpu in enumerate(gpus_util_map_host): 35 | for _ in range(num_process_on_gpu): 36 | gpu_util_map[i] = (host, gpu_j) 37 | i += 1 38 | logging.info("Process: %d" % (process_id)) 39 | logging.info("host: %s" % (gpu_util_map[process_id][0])) 40 | logging.info("gethostname: %s" % (socket.gethostname())) 41 | logging.info("gpu: %d" % (gpu_util_map[process_id][1])) 42 | assert i == worker_number 43 | 44 | device = torch.device( 45 | "cuda:" + str(gpu_util_map[process_id][1]) 46 | if torch.cuda.is_available() else "cpu") 47 | logging.info(device) 48 | # return gpu_util_map[process_id][1] 49 | return device -------------------------------------------------------------------------------- /src/utilities/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from time import time 4 | 5 | 6 | class BaseLogger(ABC): 7 | def __init__(self) -> None: 8 | super().__init__() 9 | 10 | @staticmethod 11 | def time(func): 12 | def decorated(*args, **kwargs): 13 | start_time = time() 14 | out = func(*args, **kwargs) 15 | end_time = time() 16 | logging.info("aggregate time cost: %d" % (end_time - start_time)) 17 | return out 18 | 19 | return decorated 20 | 21 | @abstractmethod 22 | def log(*args, **kwargs): 23 | pass 24 | 25 | @abstractmethod 26 | def log_gradients(*args, **kwargs): 27 | pass 28 | 29 | @abstractmethod 30 | def add_scalar(*args, **kwargs): 31 | pass 32 | 33 | @abstractmethod 34 | def add_histogram(*args, **kwargs): 35 | pass 36 | 37 | @abstractmethod 38 | def add_graph(*args, **kwargs): 39 | pass 40 | 41 | 42 | try: 43 | from torch.utils.tensorboard import SummaryWriter 44 | 45 | class TBLogger(SummaryWriter, BaseLogger): 46 | def __init__(self, log_dir, comment="", max_queue=10): 47 | super().__init__(log_dir=log_dir, 48 | comment=comment, 49 | max_queue=max_queue) 50 | 51 | def log(self, *args, **kwargs): 52 | print(*args, **kwargs) 53 | 54 | def log_gradients(self, 
model, step, to_normalize=True): 55 | for name, param in model.named_parameters(): 56 | if to_normalize: 57 | grad = param.grad.norm() 58 | self.add_scalar("grads/"+name, grad, global_step=step) 59 | else: 60 | grad = param.grad 61 | self.add_histogram("grads/"+name, grad, global_step=step) 62 | 63 | except ImportError: 64 | UserWarning("Tensorboard not installed. No Tensorboard logging.") 65 | 66 | try: 67 | import neptune 68 | 69 | class NeptuneLogger(BaseLogger): 70 | def __init__(self, log_dir, comment="", max_queue=10): 71 | super().__init__() 72 | 73 | def log(self, *args, **kwargs): 74 | print(*args, **kwargs) 75 | 76 | def log_gradients(self, model, step, to_normalize=True): 77 | for name, param in model.named_parameters(): 78 | if to_normalize: 79 | grad = param.grad.norm() 80 | self.add_scalar("grads/"+name, grad, global_step=step) 81 | else: 82 | grad = param.grad 83 | self.add_histogram("grads/"+name, grad, global_step=step) 84 | except ImportError: 85 | UserWarning("Neptune not installed. No Neptune logging.") 86 | 87 | 88 | class NoOpLogger(BaseLogger): 89 | def __init__(self) -> None: 90 | super().__init__() 91 | 92 | def log(*args, **kwargs): 93 | pass 94 | 95 | def log_gradients(*args, **kwargs): 96 | pass 97 | 98 | def add_scalar(*args, **kwargs): 99 | pass 100 | 101 | def add_histogram(*args, **kwargs): 102 | pass 103 | 104 | def add_graph(*args, **kwargs): 105 | pass 106 | -------------------------------------------------------------------------------- /src/utilities/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import attr 4 | from typing import Dict 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class RandomState: 10 | def __init__(self): 11 | self.random_mod_state = random.getstate() 12 | self.np_state = np.random.get_state() 13 | self.torch_cpu_state = torch.get_rng_state() 14 | self.torch_gpu_states = [ 15 | torch.cuda.get_rng_state(d) 16 | for d in range(torch.cuda.device_count()) 17 | ] 18 | 19 | def restore(self): 20 | random.setstate(self.random_mod_state) 21 | np.random.set_state(self.np_state) 22 | torch.set_rng_state(self.torch_cpu_state) 23 | for d, state in enumerate(self.torch_gpu_states): 24 | torch.cuda.set_rng_state(state, d) 25 | 26 | 27 | class RandomContext: 28 | '''Save and restore state of PyTorch, NumPy, Python RNGs.''' 29 | 30 | def __init__(self, seed=None): 31 | outside_state = RandomState() 32 | 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | if seed is None: 36 | torch.manual_seed(random.randint(-sys.maxsize - 1, sys.maxsize)) 37 | else: 38 | torch.manual_seed(seed) 39 | # torch.cuda.manual_seed_all is called by torch.manual_seed 40 | self.inside_state = RandomState() 41 | 42 | outside_state.restore() 43 | 44 | self._active = False 45 | 46 | def __enter__(self): 47 | if self._active: 48 | raise Exception('RandomContext can be active only once') 49 | 50 | # Save current state of RNG 51 | self.outside_state = RandomState() 52 | # Restore saved state of RNG for this context 53 | self.inside_state.restore() 54 | self._active = True 55 | 56 | def __exit__(self, exception_type, exception_value, traceback): 57 | # Save current state of RNG 58 | self.inside_state = RandomState() 59 | # Restore state of RNG saved in __enter__ 60 | self.outside_state.restore() 61 | self.outside_state = None 62 | 63 | self._active = False 64 | 65 | 66 | @attr.s 67 | class RandomizationConfig: 68 | # Seed for RNG used in shuffling the training data. 
69 | data_seed = attr.ib(default=None) 70 | # Seed for RNG used in initializing the model. 71 | init_seed = attr.ib(default=None) 72 | # Seed for RNG used in computing the model's training loss. 73 | # Only relevant with internal randomness in the model, e.g. with dropout. 74 | model_seed = attr.ib(default=None) 75 | 76 | 77 | class Reproducible(object): 78 | def __init__(self, config: Dict) -> None: 79 | self.data_random = RandomContext( 80 | config["data_seed"]) 81 | self.model_random = RandomContext( 82 | config["model_seed"]) 83 | self.init_random = RandomContext( 84 | config["init_seed"]) 85 | -------------------------------------------------------------------------------- /src/utilities/registry.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The registry class makes it easy and quick to experiment with 3 | different algorithms, model architectures and hyperparameters. 4 | We only need to decorate the class definitions with registry.load 5 | and create a yaml configuration file of all the arguments to pass. 6 | Later, if we want to change any parameter (eg. number of hidden layers, 7 | learning rate, or number of clients per round), we need not change the 8 | code but only change the parameters in yaml configuration file. 9 | for detailed explaination on the use of registry, see: 10 | github.com/NimbleEdge/EnvisEdge/blob/main/docs/Tutorial-Part-2-starting_with_nimbleedge.md 11 | ''' 12 | 13 | import collections 14 | import collections.abc 15 | import inspect 16 | import sys 17 | 18 | # a defaultdict provides default values for non-existent keys. 19 | LOOKUP_DICT = collections.defaultdict(dict) 20 | 21 | 22 | def load(kind, name): 23 | ''' 24 | A decorator to record callable object definitions 25 | for models,trainers,workers etc. 26 | Arguments 27 | ---------- 28 | kind: str 29 | Key to store in dictionary, used to specify the 30 | kind of object (eg. model, trainer). 31 | name: str 32 | Sub-key under kind key, used to specify name of 33 | of the object definition. 34 | Returns 35 | ---------- 36 | callable: 37 | Decorator function to store object definition. 38 | Examples 39 | ---------- 40 | >>> @registry.load('model', 'dlrm') 41 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 42 | ... def __init__(self, arg): 43 | ... self.arg = arg 44 | ''' 45 | 46 | assert kind != "class_map", "reserved keyword for kind \"class_map\"" 47 | registry = LOOKUP_DICT[kind] 48 | class_ref = LOOKUP_DICT["class_map"] 49 | 50 | def decorator(obj): 51 | if name in registry: 52 | raise LookupError('{} already present'.format(name, kind)) 53 | registry[name] = obj 54 | class_ref[obj.__module__ + "." + obj.__name__] = obj 55 | return obj 56 | return decorator 57 | 58 | 59 | def lookup(kind, name): 60 | ''' 61 | Returns the callable object definition stored in registry. 62 | Arguments 63 | ---------- 64 | kind: str 65 | Key to search in dictionary of registry. 66 | name: str 67 | Sub-key to search under kind key in dictionary 68 | of registry. 69 | Returns 70 | ---------- 71 | callable: 72 | Object definition stored in registry under key kind 73 | and sub-key name. 74 | Examples 75 | ---------- 76 | >>> @registry.load('model', 'dlrm') 77 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 78 | ... def __init__(self, arg): 79 | ... 
self.arg = arg 80 | >>> model = lookup('model', 'dlrm') # loads model class from registry 81 | >>> model # model is a DLRM_Net object 82 | __main__.DLRM_Net 83 | ''' 84 | 85 | # check if 'name' argument is a dictionary. 86 | # if yes, load the value under key 'name'. 87 | if isinstance(name, collections.abc.Mapping): 88 | name = name['name'] 89 | 90 | if kind not in LOOKUP_DICT: 91 | raise KeyError('Nothing registered under "{}"'.format(kind)) 92 | return LOOKUP_DICT[kind][name] 93 | 94 | 95 | def construct(kind, config, unused_keys=(), **kwargs): 96 | ''' 97 | Returns an object instance by loading definition from registry, 98 | and arguments from configuration file. 99 | Arguments 100 | ---------- 101 | kind: str 102 | Key to search in dictionary of registry. 103 | config: dict 104 | Configuration dictionary loaded from yaml file 105 | unused_keys: tuple 106 | Keys for values that are not passed as arguments to 107 | insantiate the object but are still present in config. 108 | **kwargs: dict, optional 109 | Extra arguments to pass. 110 | Returns 111 | ---------- 112 | object: 113 | Constructed object using the parameters passed in config and \**kwargs. 114 | Examples 115 | ---------- 116 | >>> @registry.load('model', 'dlrm') 117 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 118 | ... def __init__(self, arg): 119 | ... self.arg = arg 120 | >>> model = construct('model', 'drlm', (), arg = 5) 121 | >>> model.arg # model is a DLRM_Net object with arg = 5 122 | 5 123 | ''' 124 | 125 | # check if 'config' argument is a string, 126 | # if yes, make it a dictionary. 127 | if isinstance(config, str): 128 | config = {'name': config} 129 | return instantiate( 130 | lookup(kind, config), 131 | config, 132 | unused_keys + ('name',), 133 | **kwargs) 134 | 135 | 136 | def instantiate(callable, config, unused_keys=(), **kwargs): 137 | ''' 138 | Instantiates an object after verifying the parameters. 139 | Arguments 140 | ---------- 141 | callable: callable 142 | Definition of object to be instantiated. 143 | config: dict 144 | Arguments to construct the object. 145 | unused_keys: tuple 146 | Keys for values that are not passed as arguments to 147 | insantiate the object but are still present in config. 148 | **kwargs: dict, optional 149 | Extra arguments to pass. 150 | Returns 151 | ---------- 152 | object: 153 | Instantiated object by the parameters passed in config and \**kwargs. 154 | Examples 155 | ---------- 156 | >>> @registry.load('model', 'dlrm') 157 | ... class DLRM_Net(nn.Module): # This class definition gets recorded 158 | ... def __init__(self, arg): 159 | ... self.arg = arg 160 | >>> config = {'name': 'dlrm', 'arg': 5} # loaded from a yaml config file 161 | >>> call = lookup('model', 'dlrm') # Loads the class definition 162 | >>> model = instantiate(call, config, ('name')) 163 | >>> model.arg # model is a DRLM_Net object with arg = 5 164 | 5 165 | ''' 166 | 167 | # merge config arguments and kwargs in a single dictionary. 168 | merged = {**config, **kwargs} 169 | 170 | # check if callable has valid parameters. 
171 | signature = inspect.signature(callable) 172 | for name, param in signature.parameters.items(): 173 | if param.kind in (inspect.Parameter.POSITIONAL_ONLY, 174 | inspect.Parameter.VAR_POSITIONAL): 175 | raise ValueError('Unsupported kind for param {}: {}'.format( 176 | name, param.kind)) 177 | 178 | if any(param.kind == inspect.Parameter.VAR_KEYWORD 179 | for param in signature.parameters.values()): 180 | return callable(**merged) 181 | 182 | # check and warn if config has unneccassary arguments that 183 | # callable does not require and are not mentioned in unused_keys. 184 | missing = {} 185 | for key in list(merged.keys()): 186 | if key not in signature.parameters: 187 | if key not in unused_keys: 188 | missing[key] = merged[key] 189 | merged.pop(key) 190 | if missing: 191 | print('WARNING {}: superfluous {}'.format( 192 | callable, missing), file=sys.stderr) 193 | return callable(**merged) 194 | 195 | -------------------------------------------------------------------------------- /src/utilities/saver.py: -------------------------------------------------------------------------------- 1 | """Tools to save/restore model from checkpoints.""" 2 | 3 | import shutil 4 | import os 5 | import re 6 | 7 | import torch 8 | 9 | CHECKPOINT_PATTERN = re.compile('^model_checkpoint-(\d+)$') 10 | 11 | 12 | class ArgsDict(dict): 13 | 14 | def __init__(self, **kwargs): 15 | super(ArgsDict, self).__init__() 16 | for key, value in kwargs.items(): 17 | self[key] = value 18 | self.__dict__ = self 19 | 20 | 21 | def create_link(original, link_name): 22 | if os.path.islink(link_name): 23 | os.unlink(link_name) 24 | try: 25 | os.symlink(os.path.basename(original), link_name) 26 | except OSError: 27 | shutil.copy2(original, link_name) 28 | 29 | 30 | def load_checkpoint(model, 31 | optimizer, 32 | model_dir, 33 | map_location=None, 34 | step=None): 35 | path = os.path.join(model_dir, 'model_checkpoint') 36 | if step is not None: 37 | path += '-{:08d}'.format(step) 38 | if os.path.exists(path): 39 | print("Loading model from %s" % path) 40 | checkpoint = torch.load(path, map_location=map_location) 41 | model.load_state_dict(checkpoint['model'], strict=False) 42 | optimizer.load_state_dict(checkpoint['optimizer']) 43 | return checkpoint.get('step', 0), checkpoint.get('epoch', 0) 44 | return 0, 0 45 | 46 | 47 | def load_and_map_checkpoint(model, model_dir, remap): 48 | path = os.path.join(model_dir, 'model_checkpoint') 49 | print("Loading parameters %s from %s" % (remap.keys(), model_dir)) 50 | checkpoint = torch.load(path) 51 | new_state_dict = model.state_dict() 52 | for name, value in remap.items(): 53 | # TODO: smarter mapping. 
54 | new_state_dict[name] = checkpoint['model'][value] 55 | model.load_state_dict(new_state_dict) 56 | 57 | 58 | def save_checkpoint(model, 59 | optimizer, 60 | step, 61 | epoch, 62 | model_dir, 63 | is_best, 64 | ignore=[], 65 | keep_every_n=10000000): 66 | if not os.path.exists(model_dir): 67 | os.makedirs(model_dir) 68 | path_without_step = os.path.join(model_dir, 'model_checkpoint') 69 | step_padded = format(step, '08d') 70 | state_dict = model.state_dict() 71 | if ignore: 72 | for key in list(state_dict.keys()): # copy keys so entries can be popped while iterating 73 | for item in ignore: 74 | if key.startswith(item): 75 | state_dict.pop(key) 76 | path_with_step = '{}-{}'.format(path_without_step, step_padded) 77 | torch.save({ 78 | 'model': state_dict, 79 | 'optimizer': optimizer.state_dict(), 80 | 'epoch': epoch, 81 | 'step': step 82 | }, path_with_step) 83 | create_link(path_with_step, path_without_step) 84 | if is_best: create_link(path_with_step, os.path.join(model_dir, 'best_checkpoint')) 85 | 86 | # Cull old checkpoints. 87 | if keep_every_n is not None: 88 | all_checkpoints = [] 89 | for name in os.listdir(model_dir): 90 | m = CHECKPOINT_PATTERN.match(name) 91 | if m is None or name == os.path.basename(path_with_step): 92 | continue 93 | checkpoint_step = int(m.group(1)) 94 | all_checkpoints.append((checkpoint_step, name)) 95 | all_checkpoints.sort() 96 | 97 | last_step = float('-inf') 98 | for checkpoint_step, name in all_checkpoints: 99 | if checkpoint_step - last_step >= keep_every_n: 100 | last_step = checkpoint_step 101 | continue 102 | os.unlink(os.path.join(model_dir, name)) 103 | 104 | 105 | class Saver(object): 106 | """Class to manage save and restore for the model and optimizer.""" 107 | 108 | def __init__(self, model, optimizer, keep_every_n=None): 109 | self._model = model 110 | self._optimizer = optimizer 111 | self._keep_every_n = keep_every_n 112 | 113 | def restore(self, model_dir=None, map_location=None, step=None): 114 | """Restores model and optimizer from given directory. 115 | Returns 116 | Last training step for the model restored. 117 | """ 118 | if model_dir is None: 119 | return 0, 0 120 | last_step, epoch = load_checkpoint( 121 | self._model, self._optimizer, model_dir, map_location, step) 122 | return last_step, epoch 123 | 124 | def save(self, model_dir, step, epoch, is_best=False): 125 | """Saves model and optimizer to given directory. 126 | Args: 127 | model_dir: Model directory to save. If None ignore. 128 | step: Current training step. 129 | """ 130 | if model_dir is None: 131 | return 132 | save_checkpoint(self._model, self._optimizer, step, epoch, model_dir, 133 | keep_every_n=self._keep_every_n, is_best=is_best) 134 | 135 | def restore_part(self, other_model_dir, remap): 136 | """Restores part of the model from other directory. 137 | Useful to initialize part of the model with another pretrained model. 138 | Args: 139 | other_model_dir: Model directory to load from. 140 | remap: dict, remapping current parameters to the other model's. 141 | """ 142 | load_and_map_checkpoint(self._model, other_model_dir, remap) -------------------------------------------------------------------------------- /train_predictors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Training script for sparsity predictors in LlamaSkipConnection models. 4 | 5 | This script trains the LoRA-based sparsity predictors to predict which neurons 6 | will be most important based on ground truth activations from standard LLaMA.
7 | 8 | The script uses the C4 (Colossal Clean Crawled Corpus) dataset by default, which 9 | is the same dataset used to train LLaMA models. C4 provides high-quality English 10 | text data that matches the training distribution of the original LLaMA models, 11 | making it ideal for training sparsity predictors. 12 | 13 | C4 Dataset: 14 | - 750GB of English-language text from Common Crawl 15 | - Cleaned and filtered web content 16 | - Used to train Google's T5/LaMDA and Meta's LLaMA models 17 | - Available via HuggingFace datasets as 'allenai/c4' 18 | 19 | Usage: 20 | python train_predictors.py \ 21 | --config configs/llama_skip_causal_3b_predictor_training.json \ 22 | --output_dir ./trained_predictors \ 23 | --num_samples 50000 \ 24 | --batch_size 8 \ 25 | --num_epochs 5 \ 26 | --use_wandb 27 | """ 28 | 29 | import argparse 30 | import json 31 | import logging 32 | import os 33 | import time 34 | from typing import Dict, List, Optional, Tuple 35 | 36 | import torch 37 | import torch.nn.functional as F 38 | from torch.utils.data import DataLoader, Dataset 39 | from transformers import ( 40 | AutoTokenizer, 41 | AutoConfig, 42 | AutoModelForCausalLM, 43 | get_linear_schedule_with_warmup, 44 | set_seed 45 | ) 46 | from datasets import load_dataset 47 | import wandb 48 | from tqdm import tqdm 49 | 50 | from src.models.llama.modelling_llama_skip import LlamaSkipConnectionForCausalLM 51 | from src.models.llama.configuration_llama_skip import LlamaSkipConnectionConfig 52 | 53 | # Setup logging 54 | logging.basicConfig(level=logging.INFO) 55 | logger = logging.getLogger(__name__) 56 | 57 | 58 | class TextDataset(Dataset): 59 | """Dataset for training sparsity predictors.""" 60 | 61 | def __init__(self, texts: List[str], tokenizer, max_length: int = 512): 62 | self.texts = texts 63 | self.tokenizer = tokenizer 64 | self.max_length = max_length 65 | 66 | def __len__(self): 67 | return len(self.texts) 68 | 69 | def __getitem__(self, idx): 70 | text = self.texts[idx] 71 | encoding = self.tokenizer( 72 | text, 73 | truncation=True, 74 | padding='max_length', 75 | max_length=self.max_length, 76 | return_tensors='pt' 77 | ) 78 | return { 79 | 'input_ids': encoding['input_ids'].squeeze(), 80 | 'attention_mask': encoding['attention_mask'].squeeze() 81 | } 82 | 83 | 84 | def load_training_data(dataset_name: str = "allenai/c4", 85 | dataset_config: str = "realnewslike", 86 | num_samples: int = 10000, 87 | max_length: int = 512) -> Tuple[List[str], List[str]]: 88 | """Load and prepare training data.""" 89 | logger.info(f"Loading dataset: {dataset_name}/{dataset_config}") 90 | 91 | if dataset_name == "allenai/c4": 92 | # Load C4 dataset with streaming for efficiency 93 | dataset = load_dataset(dataset_name, dataset_config, split="train", streaming=True) 94 | 95 | # Extract text samples from streaming dataset 96 | train_texts = [] 97 | val_texts = [] 98 | 99 | logger.info("Extracting samples from C4 dataset...") 100 | for i, sample in enumerate(dataset): 101 | if i >= num_samples + 1000: # Extra samples for validation 102 | break 103 | 104 | text = sample['text'] 105 | if len(text.strip()) > 50: # Filter out very short texts 106 | if i < num_samples: 107 | train_texts.append(text) 108 | else: 109 | val_texts.append(text) 110 | 111 | # Ensure we have validation samples 112 | if not val_texts: 113 | val_texts = train_texts[-1000:] # Use last 1000 as validation 114 | train_texts = train_texts[:-1000] 115 | 116 | else: 117 | # Original logic for other datasets like WikiText 118 | dataset = 
load_dataset(dataset_name, dataset_config) 119 | 120 | # Filter out empty texts and combine train/validation 121 | train_texts = [text for text in dataset['train']['text'] if text.strip()] 122 | val_texts = [text for text in dataset['validation']['text'] if text.strip()] 123 | 124 | # Limit number of samples 125 | train_texts = train_texts[:num_samples] 126 | val_texts = val_texts[:min(1000, len(val_texts))] 127 | 128 | logger.info(f"Loaded {len(train_texts)} training samples, {len(val_texts)} validation samples") 129 | return train_texts, val_texts 130 | 131 | 132 | def evaluate_predictor_accuracy(model: LlamaSkipConnectionForCausalLM, 133 | dataloader: DataLoader, 134 | device: torch.device, 135 | max_batches: int = 50) -> Dict[str, float]: 136 | """Evaluate predictor accuracy against ground truth.""" 137 | model.eval() # Set to evaluation mode 138 | 139 | total_accuracy = 0.0 140 | total_precision = 0.0 141 | total_recall = 0.0 142 | total_f1 = 0.0 143 | num_batches = 0 144 | 145 | with torch.no_grad(): 146 | for batch_idx, batch in enumerate(dataloader): 147 | if batch_idx >= max_batches: 148 | break 149 | 150 | input_ids = batch['input_ids'].to(device) 151 | attention_mask = batch['attention_mask'].to(device) 152 | 153 | # Forward pass to collect predictor scores and ground truth 154 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 155 | 156 | # Calculate metrics for each layer 157 | layer_accuracies = [] 158 | layer_precisions = [] 159 | layer_recalls = [] 160 | layer_f1s = [] 161 | 162 | for layer in model.model.layers: 163 | if hasattr(layer, 'training_mode') and layer.training_mode: 164 | # Get last hidden states for this evaluation 165 | hidden_states = outputs.hidden_states[-1] if outputs.hidden_states else None 166 | if hidden_states is not None: 167 | # Get predictor scores 168 | hidden_reshaped = hidden_states.view(-1, hidden_states.shape[-1]) 169 | pred_scores = layer.mlp_lora_proj(hidden_reshaped) 170 | 171 | # Get ground truth activations 172 | gt_activations = layer.get_ground_truth_activations(hidden_reshaped) 173 | 174 | # Create binary masks 175 | k = int(layer.sparsity * pred_scores.shape[-1]) 176 | _, gt_indices = torch.topk(torch.abs(gt_activations), k, dim=-1) 177 | _, pred_indices = torch.topk(pred_scores, k, dim=-1) 178 | 179 | # Calculate metrics 180 | gt_mask = torch.zeros_like(pred_scores, dtype=torch.bool) 181 | pred_mask = torch.zeros_like(pred_scores, dtype=torch.bool) 182 | gt_mask.scatter_(1, gt_indices, True) 183 | pred_mask.scatter_(1, pred_indices, True) 184 | 185 | # Accuracy 186 | correct = (gt_mask == pred_mask).float().mean() 187 | layer_accuracies.append(correct.item()) 188 | 189 | # Precision, Recall, F1 190 | tp = (gt_mask & pred_mask).sum().float() 191 | fp = (~gt_mask & pred_mask).sum().float() 192 | fn = (gt_mask & ~pred_mask).sum().float() 193 | 194 | precision = tp / (tp + fp + 1e-8) 195 | recall = tp / (tp + fn + 1e-8) 196 | f1 = 2 * precision * recall / (precision + recall + 1e-8) 197 | 198 | layer_precisions.append(precision.item()) 199 | layer_recalls.append(recall.item()) 200 | layer_f1s.append(f1.item()) 201 | 202 | if layer_accuracies: 203 | total_accuracy += sum(layer_accuracies) / len(layer_accuracies) 204 | total_precision += sum(layer_precisions) / len(layer_precisions) 205 | total_recall += sum(layer_recalls) / len(layer_recalls) 206 | total_f1 += sum(layer_f1s) / len(layer_f1s) 207 | num_batches += 1 208 | 209 | model.train() # Switch back to training mode 210 | 211 | if num_batches == 0: 212 | return 
{"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0} 213 | 214 | return { 215 | "accuracy": total_accuracy / num_batches, 216 | "precision": total_precision / num_batches, 217 | "recall": total_recall / num_batches, 218 | "f1": total_f1 / num_batches 219 | } 220 | 221 | 222 | def train_predictors( 223 | model: LlamaSkipConnectionForCausalLM, 224 | train_dataloader: DataLoader, 225 | val_dataloader: DataLoader, 226 | num_epochs: int, 227 | learning_rate: float, 228 | device: torch.device, 229 | save_dir: str, 230 | eval_steps: int = 500, 231 | save_steps: int = 1000, 232 | use_wandb: bool = False 233 | ) -> None: 234 | """Train the sparsity predictors using PyTorch's built-in training mode.""" 235 | 236 | # Setup training mode using PyTorch's built-in methods 237 | model.train() # Enable training mode 238 | model.freeze_non_predictor_parameters() 239 | 240 | # Setup optimizer 241 | predictor_params = model.get_predictor_parameters() 242 | optimizer = torch.optim.AdamW(predictor_params, lr=learning_rate, weight_decay=0.01) 243 | 244 | # Setup scheduler 245 | total_steps = len(train_dataloader) * num_epochs 246 | scheduler = get_linear_schedule_with_warmup( 247 | optimizer, 248 | num_warmup_steps=int(0.1 * total_steps), 249 | num_training_steps=total_steps 250 | ) 251 | 252 | # Training loop 253 | global_step = 0 254 | best_f1 = 0.0 255 | 256 | logger.info(f"Starting training for {num_epochs} epochs") 257 | logger.info(f"Total training steps: {total_steps}") 258 | logger.info(f"Number of predictor parameters: {sum(p.numel() for p in predictor_params)}") 259 | 260 | for epoch in range(num_epochs): 261 | epoch_loss = 0.0 262 | epoch_predictor_loss = 0.0 263 | num_batches = 0 264 | 265 | progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") 266 | 267 | for batch in progress_bar: 268 | input_ids = batch['input_ids'].to(device) 269 | attention_mask = batch['attention_mask'].to(device) 270 | 271 | # Forward pass 272 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 273 | loss = outputs.loss 274 | 275 | # Backward pass 276 | optimizer.zero_grad() 277 | loss.backward() 278 | torch.nn.utils.clip_grad_norm_(predictor_params, max_norm=1.0) 279 | optimizer.step() 280 | scheduler.step() 281 | 282 | # Update metrics 283 | epoch_loss += loss.item() 284 | epoch_predictor_loss += loss.item() 285 | num_batches += 1 286 | global_step += 1 287 | 288 | # Update progress bar 289 | progress_bar.set_postfix({ 290 | 'loss': f"{loss.item():.4f}", 291 | 'lr': f"{scheduler.get_last_lr()[0]:.2e}" 292 | }) 293 | 294 | # Evaluation 295 | if global_step % eval_steps == 0: 296 | logger.info(f"Evaluating at step {global_step}") 297 | try: 298 | eval_metrics = evaluate_predictor_accuracy(model, val_dataloader, device) 299 | logger.info(f"Evaluation metrics: {eval_metrics}") 300 | 301 | if use_wandb: 302 | wandb.log({ 303 | "eval/accuracy": eval_metrics["accuracy"], 304 | "eval/precision": eval_metrics["precision"], 305 | "eval/recall": eval_metrics["recall"], 306 | "eval/f1": eval_metrics["f1"], 307 | "step": global_step 308 | }) 309 | 310 | # Save best model 311 | if eval_metrics["f1"] > best_f1: 312 | best_f1 = eval_metrics["f1"] 313 | best_model_path = os.path.join(save_dir, "best_predictors") 314 | try: 315 | model.save_pretrained(best_model_path) 316 | logger.info(f"Saved best model with F1: {best_f1:.4f}") 317 | except Exception as e: 318 | logger.warning(f"Failed to save best model: {e}") 319 | 320 | # Model is already back in training mode from 
evaluate_predictor_accuracy 321 | except Exception as e: 322 | logger.warning(f"Evaluation failed at step {global_step}: {e}") 323 | model.train() # Ensure we're back in training mode 324 | 325 | # Save checkpoint 326 | if global_step % save_steps == 0: 327 | checkpoint_path = os.path.join(save_dir, f"checkpoint-{global_step}") 328 | model.save_pretrained(checkpoint_path) 329 | logger.info(f"Saved checkpoint at step {global_step}") 330 | 331 | # Log training metrics 332 | if use_wandb and global_step % 100 == 0: 333 | wandb.log({ 334 | "train/loss": loss.item(), 335 | "train/learning_rate": scheduler.get_last_lr()[0], 336 | "step": global_step 337 | }) 338 | 339 | # End of epoch logging 340 | avg_loss = epoch_loss / num_batches 341 | logger.info(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}") 342 | 343 | if use_wandb: 344 | wandb.log({ 345 | "train/epoch_loss": avg_loss, 346 | "epoch": epoch + 1 347 | }) 348 | 349 | # Final save 350 | final_model_path = os.path.join(save_dir, "final_predictors") 351 | model.save_pretrained(final_model_path) 352 | logger.info("Training completed!") 353 | 354 | 355 | def main(): 356 | parser = argparse.ArgumentParser(description="Train sparsity predictors for LlamaSkipConnection") 357 | parser.add_argument("--config", type=str, required=True, help="Path to model config file") 358 | parser.add_argument("--output_dir", type=str, required=True, help="Output directory for trained models") 359 | parser.add_argument("--dataset", type=str, default="allenai/c4", help="Dataset name (default: allenai/c4)") 360 | parser.add_argument("--dataset_config", type=str, default="realnewslike", 361 | help="Dataset configuration (default: realnewslike for C4)") 362 | parser.add_argument("--num_samples", type=int, default=10000, 363 | help="Number of training samples (default: 10000)") 364 | parser.add_argument("--max_length", type=int, default=512, help="Maximum sequence length") 365 | parser.add_argument("--batch_size", type=int, default=4, help="Training batch size") 366 | parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs") 367 | parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") 368 | parser.add_argument("--eval_steps", type=int, default=500, help="Evaluation frequency") 369 | parser.add_argument("--save_steps", type=int, default=1000, help="Save frequency") 370 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 371 | parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases logging") 372 | parser.add_argument("--wandb_project", type=str, default="llama-skip-predictors", help="W&B project name") 373 | parser.add_argument("--wandb_entity", type=str, default="llama-skip-predictors", help="W&B entity name") 374 | parser.add_argument("--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)") 375 | 376 | args = parser.parse_args() 377 | 378 | # Set seed 379 | set_seed(args.seed) 380 | 381 | # Setup device 382 | if args.device == "auto": 383 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 384 | else: 385 | device = torch.device(args.device) 386 | 387 | logger.info(f"Using device: {device}") 388 | 389 | # Setup output directory 390 | os.makedirs(args.output_dir, exist_ok=True) 391 | 392 | # Initialize wandb 393 | if args.use_wandb: 394 | wandb.init( 395 | project=args.wandb_project, 396 | entity=args.wandb_entity, 397 | config=vars(args), 398 | name=f"predictor-training-{int(time.time())}" 399 | ) 400 | 401 
| # Load model configuration 402 | config = LlamaSkipConnectionConfig.from_json_file(args.config) 403 | checkpoint = config._name_or_path 404 | 405 | # Register custom models 406 | AutoConfig.register("llama-skip", LlamaSkipConnectionConfig) 407 | AutoModelForCausalLM.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM) 408 | 409 | # Load tokenizer 410 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) 411 | tokenizer.pad_token = tokenizer.eos_token 412 | 413 | # Load model 414 | logger.info("Loading model...") 415 | model = LlamaSkipConnectionForCausalLM.from_pretrained(checkpoint, config=config) 416 | model = model.to(device) 417 | 418 | # Load training data 419 | train_texts, val_texts = load_training_data( 420 | args.dataset, args.dataset_config, args.num_samples, args.max_length 421 | ) 422 | 423 | # Create datasets and dataloaders 424 | train_dataset = TextDataset(train_texts, tokenizer, args.max_length) 425 | val_dataset = TextDataset(val_texts, tokenizer, args.max_length) 426 | 427 | train_dataloader = DataLoader( 428 | train_dataset, 429 | batch_size=args.batch_size, 430 | shuffle=True, 431 | num_workers=4, 432 | pin_memory=True 433 | ) 434 | val_dataloader = DataLoader( 435 | val_dataset, 436 | batch_size=args.batch_size, 437 | shuffle=False, 438 | num_workers=4, 439 | pin_memory=True 440 | ) 441 | 442 | # Train predictors 443 | train_predictors( 444 | model=model, 445 | train_dataloader=train_dataloader, 446 | val_dataloader=val_dataloader, 447 | num_epochs=args.num_epochs, 448 | learning_rate=args.learning_rate, 449 | device=device, 450 | save_dir=args.output_dir, 451 | eval_steps=args.eval_steps, 452 | save_steps=args.save_steps, 453 | use_wandb=args.use_wandb 454 | ) 455 | 456 | if args.use_wandb: 457 | wandb.finish() 458 | 459 | 460 | if __name__ == "__main__": 461 | main() --------------------------------------------------------------------------------
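As a quick-start complement to train_predictors.py, the sketch below shows how the custom llama-skip model class can be loaded for plain inference, following the same registration pattern used in main() above. It is an illustrative sketch only: the config path is one of the predictor-training configs under configs/, the prompt and decoding arguments are placeholders, and it assumes LlamaSkipConnectionForCausalLM supports the standard Hugging Face generate() API like the base Llama model.

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from src.models.llama.configuration_llama_skip import LlamaSkipConnectionConfig
from src.models.llama.modelling_llama_skip import LlamaSkipConnectionForCausalLM

# Register the custom config/model pair so the Auto* classes can resolve "llama-skip".
AutoConfig.register("llama-skip", LlamaSkipConnectionConfig)
AutoModelForCausalLM.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM)

# The JSON config carries the sparsity/predictor settings and points at the base checkpoint.
config = LlamaSkipConnectionConfig.from_json_file(
    "configs/llama_skip_causal_3b_predictor_training.json")
checkpoint = config._name_or_path

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

model = LlamaSkipConnectionForCausalLM.from_pretrained(checkpoint, config=config)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Placeholder prompt and decoding settings.
inputs = tokenizer("Sparse transformers are", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))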