├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGEMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── cce_figure.png ├── benchmark ├── __main__.py ├── data │ ├── __init__.py │ ├── data.py │ ├── models.py │ └── randn.py └── memory.py ├── cut_cross_entropy ├── __init__.py ├── cce.py ├── cce_backward.py ├── cce_lse_forward.py ├── cce_utils.py ├── constants.py ├── doc.py ├── indexed_dot.py ├── linear_cross_entropy.py ├── tl_autotune.py ├── tl_utils.py ├── torch_compile.py ├── transformers │ ├── __init__.py │ ├── gemma2.py │ ├── llama.py │ ├── mistral.py │ ├── patch.py │ ├── phi3.py │ ├── qwen2.py │ └── utils.py ├── utils.py └── vocab_parallel │ ├── __init__.py │ ├── utils.py │ └── vocab_parallel_torch_compile.py ├── pyproject.toml ├── scripts └── train.sh ├── tests ├── test_cce_indexed_dot.py ├── test_cce_loss_backward.py ├── test_cce_loss_forward.py ├── test_cce_lse.py └── test_vocab_parallel.py └── training ├── train.py └── zero3.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.jsonl 11 | 12 | # Data 13 | !**/alpaca-data-conversation.json 14 | 15 | # Editor 16 | .idea 17 | *.swp 18 | 19 | # Other 20 | .DS_Store 21 | wandb 22 | output 23 | 24 | checkpoints 25 | ckpts* 26 | 27 | .ipynb_checkpoints 28 | *.ipynb 29 | 30 | # DevContainer 31 | !.devcontainer/* 32 | 33 | # Demo 34 | serve_images/ 35 | 36 | /datasets/ 37 | /data/ 38 | my_scripts/ 39 | 40 | .autoenv*.zsh 41 | 42 | .model_cache/ 43 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-merge-conflict 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.5.5 11 | hooks: 12 | - id: ruff 13 | args: [ --fix, --show-fixes ] 14 | types: [python] 15 | - id: ruff-format 16 | types: [python] 17 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS.md: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | 3 | Portions of this Cut Cross Entropy Software may utilize the following copyrighted 4 | material, the use of which is hereby acknowledged. 5 | 6 | 7 | ------ 8 | 9 | 10 | PyTorch 11 | 12 | From PyTorch: 13 | 14 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 15 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 16 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 17 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 18 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 19 | Copyright (c) 2011-2013 NYU (Clement Farabet) 20 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 21 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 22 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 23 | 24 | From Caffe2: 25 | 26 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 27 | 28 | All contributions by Facebook: 29 | Copyright (c) 2016 Facebook Inc. 30 | 31 | All contributions by Google: 32 | Copyright (c) 2015 Google Inc. 
33 | All rights reserved. 34 | 35 | All contributions by Yangqing Jia: 36 | Copyright (c) 2015 Yangqing Jia 37 | All rights reserved. 38 | 39 | All contributions by Kakao Brain: 40 | Copyright 2019-2020 Kakao Brain 41 | 42 | All contributions by Cruise LLC: 43 | Copyright (c) 2022 Cruise LLC. 44 | All rights reserved. 45 | 46 | All contributions by Arm: 47 | Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates 48 | 49 | All contributions from Caffe: 50 | Copyright(c) 2013, 2014, 2015, the respective contributors 51 | All rights reserved. 52 | 53 | All other contributions: 54 | Copyright(c) 2015, 2016 the respective contributors 55 | All rights reserved. 56 | 57 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 58 | copyright over their contributions to Caffe2. The project versioning records 59 | all such contribution and copyright details. If a contributor wants to further 60 | mark their specific copyright on a particular contribution, they should 61 | indicate their copyright solely in the commit message of the change when it is 62 | committed. 63 | 64 | All rights reserved. 65 | 66 | Redistribution and use in source and binary forms, with or without 67 | modification, are permitted provided that the following conditions are met: 68 | 69 | 1. Redistributions of source code must retain the above copyright 70 | notice, this list of conditions and the following disclaimer. 71 | 72 | 2. Redistributions in binary form must reproduce the above copyright 73 | notice, this list of conditions and the following disclaimer in the 74 | documentation and/or other materials provided with the distribution. 75 | 76 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 77 | and IDIAP Research Institute nor the names of its contributors may be 78 | used to endorse or promote products derived from this software without 79 | specific prior written permission. 80 | 81 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 82 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 83 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 84 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 85 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 86 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 87 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 88 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 89 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 90 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 91 | POSSIBILITY OF SUCH DAMAGE. 92 | 93 | 94 | Triton 95 | 96 | /* 97 | * Copyright 2018-2020 Philippe Tillet 98 | * Copyright 2020-2022 OpenAI 99 | * 100 | * Permission is hereby granted, free of charge, to any person obtaining 101 | * a copy of this software and associated documentation files 102 | * (the "Software"), to deal in the Software without restriction, 103 | * including without limitation the rights to use, copy, modify, merge, 104 | * publish, distribute, sublicense, and/or sell copies of the Software, 105 | * and to permit persons to whom the Software is furnished to do so, 106 | * subject to the following conditions: 107 | * 108 | * The above copyright notice and this permission notice shall be 109 | * included in all copies or substantial portions of the Software. 
110 | * 111 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 112 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 113 | * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 114 | * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 115 | * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 116 | * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 117 | * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 118 | */ 119 | 120 | 121 | Transformers 122 | 123 | Copyright 2018- The Hugging Face team. All rights reserved. 124 | 125 | Apache License 126 | Version 2.0, January 2004 127 | http://www.apache.org/licenses/ 128 | 129 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 130 | 131 | 1. Definitions. 132 | 133 | "License" shall mean the terms and conditions for use, reproduction, 134 | and distribution as defined by Sections 1 through 9 of this document. 135 | 136 | "Licensor" shall mean the copyright owner or entity authorized by 137 | the copyright owner that is granting the License. 138 | 139 | "Legal Entity" shall mean the union of the acting entity and all 140 | other entities that control, are controlled by, or are under common 141 | control with that entity. For the purposes of this definition, 142 | "control" means (i) the power, direct or indirect, to cause the 143 | direction or management of such entity, whether by contract or 144 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 145 | outstanding shares, or (iii) beneficial ownership of such entity. 146 | 147 | "You" (or "Your") shall mean an individual or Legal Entity 148 | exercising permissions granted by this License. 149 | 150 | "Source" form shall mean the preferred form for making modifications, 151 | including but not limited to software source code, documentation 152 | source, and configuration files. 153 | 154 | "Object" form shall mean any form resulting from mechanical 155 | transformation or translation of a Source form, including but 156 | not limited to compiled object code, generated documentation, 157 | and conversions to other media types. 158 | 159 | "Work" shall mean the work of authorship, whether in Source or 160 | Object form, made available under the License, as indicated by a 161 | copyright notice that is included in or attached to the work 162 | (an example is provided in the Appendix below). 163 | 164 | "Derivative Works" shall mean any work, whether in Source or Object 165 | form, that is based on (or derived from) the Work and for which the 166 | editorial revisions, annotations, elaborations, or other modifications 167 | represent, as a whole, an original work of authorship. For the purposes 168 | of this License, Derivative Works shall not include works that remain 169 | separable from, or merely link (or bind by name) to the interfaces of, 170 | the Work and Derivative Works thereof. 171 | 172 | "Contribution" shall mean any work of authorship, including 173 | the original version of the Work and any modifications or additions 174 | to that Work or Derivative Works thereof, that is intentionally 175 | submitted to Licensor for inclusion in the Work by the copyright owner 176 | or by an individual or Legal Entity authorized to submit on behalf of 177 | the copyright owner. 
For the purposes of this definition, "submitted" 178 | means any form of electronic, verbal, or written communication sent 179 | to the Licensor or its representatives, including but not limited to 180 | communication on electronic mailing lists, source code control systems, 181 | and issue tracking systems that are managed by, or on behalf of, the 182 | Licensor for the purpose of discussing and improving the Work, but 183 | excluding communication that is conspicuously marked or otherwise 184 | designated in writing by the copyright owner as "Not a Contribution." 185 | 186 | "Contributor" shall mean Licensor and any individual or Legal Entity 187 | on behalf of whom a Contribution has been received by Licensor and 188 | subsequently incorporated within the Work. 189 | 190 | 2. Grant of Copyright License. Subject to the terms and conditions of 191 | this License, each Contributor hereby grants to You a perpetual, 192 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 193 | copyright license to reproduce, prepare Derivative Works of, 194 | publicly display, publicly perform, sublicense, and distribute the 195 | Work and such Derivative Works in Source or Object form. 196 | 197 | 3. Grant of Patent License. Subject to the terms and conditions of 198 | this License, each Contributor hereby grants to You a perpetual, 199 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 200 | (except as stated in this section) patent license to make, have made, 201 | use, offer to sell, sell, import, and otherwise transfer the Work, 202 | where such license applies only to those patent claims licensable 203 | by such Contributor that are necessarily infringed by their 204 | Contribution(s) alone or by combination of their Contribution(s) 205 | with the Work to which such Contribution(s) was submitted. If You 206 | institute patent litigation against any entity (including a 207 | cross-claim or counterclaim in a lawsuit) alleging that the Work 208 | or a Contribution incorporated within the Work constitutes direct 209 | or contributory patent infringement, then any patent licenses 210 | granted to You under this License for that Work shall terminate 211 | as of the date such litigation is filed. 212 | 213 | 4. Redistribution. 
You may reproduce and distribute copies of the 214 | Work or Derivative Works thereof in any medium, with or without 215 | modifications, and in Source or Object form, provided that You 216 | meet the following conditions: 217 | 218 | (a) You must give any other recipients of the Work or 219 | Derivative Works a copy of this License; and 220 | 221 | (b) You must cause any modified files to carry prominent notices 222 | stating that You changed the files; and 223 | 224 | (c) You must retain, in the Source form of any Derivative Works 225 | that You distribute, all copyright, patent, trademark, and 226 | attribution notices from the Source form of the Work, 227 | excluding those notices that do not pertain to any part of 228 | the Derivative Works; and 229 | 230 | (d) If the Work includes a "NOTICE" text file as part of its 231 | distribution, then any Derivative Works that You distribute must 232 | include a readable copy of the attribution notices contained 233 | within such NOTICE file, excluding those notices that do not 234 | pertain to any part of the Derivative Works, in at least one 235 | of the following places: within a NOTICE text file distributed 236 | as part of the Derivative Works; within the Source form or 237 | documentation, if provided along with the Derivative Works; or, 238 | within a display generated by the Derivative Works, if and 239 | wherever such third-party notices normally appear. The contents 240 | of the NOTICE file are for informational purposes only and 241 | do not modify the License. You may add Your own attribution 242 | notices within Derivative Works that You distribute, alongside 243 | or as an addendum to the NOTICE text from the Work, provided 244 | that such additional attribution notices cannot be construed 245 | as modifying the License. 246 | 247 | You may add Your own copyright statement to Your modifications and 248 | may provide additional or different license terms and conditions 249 | for use, reproduction, or distribution of Your modifications, or 250 | for any such Derivative Works as a whole, provided Your use, 251 | reproduction, and distribution of the Work otherwise complies with 252 | the conditions stated in this License. 253 | 254 | 5. Submission of Contributions. Unless You explicitly state otherwise, 255 | any Contribution intentionally submitted for inclusion in the Work 256 | by You to the Licensor shall be under the terms and conditions of 257 | this License, without any additional terms or conditions. 258 | Notwithstanding the above, nothing herein shall supersede or modify 259 | the terms of any separate license agreement you may have executed 260 | with Licensor regarding such Contributions. 261 | 262 | 6. Trademarks. This License does not grant permission to use the trade 263 | names, trademarks, service marks, or product names of the Licensor, 264 | except as required for reasonable and customary use in describing the 265 | origin of the Work and reproducing the content of the NOTICE file. 266 | 267 | 7. Disclaimer of Warranty. Unless required by applicable law or 268 | agreed to in writing, Licensor provides the Work (and each 269 | Contributor provides its Contributions) on an "AS IS" BASIS, 270 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 271 | implied, including, without limitation, any warranties or conditions 272 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 273 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 274 | appropriateness of using or redistributing the Work and assume any 275 | risks associated with Your exercise of permissions under this License. 276 | 277 | 8. Limitation of Liability. In no event and under no legal theory, 278 | whether in tort (including negligence), contract, or otherwise, 279 | unless required by applicable law (such as deliberate and grossly 280 | negligent acts) or agreed to in writing, shall any Contributor be 281 | liable to You for damages, including any direct, indirect, special, 282 | incidental, or consequential damages of any character arising as a 283 | result of this License or out of the use or inability to use the 284 | Work (including but not limited to damages for loss of goodwill, 285 | work stoppage, computer failure or malfunction, or any and all 286 | other commercial damages or losses), even if such Contributor 287 | has been advised of the possibility of such damages. 288 | 289 | 9. Accepting Warranty or Additional Liability. While redistributing 290 | the Work or Derivative Works thereof, You may choose to offer, 291 | and charge a fee for, acceptance of support, warranty, indemnity, 292 | or other liability obligations and/or rights consistent with this 293 | License. However, in accepting such obligations, You may act only 294 | on Your own behalf and on Your sole responsibility, not on behalf 295 | of any other Contributor, and only if You agree to indemnify, 296 | defend, and hold each Contributor harmless for any liability 297 | incurred by, or claims asserted against, such Contributor by reason 298 | of your accepting any such warranty or additional liability. 299 | 300 | END OF TERMS AND CONDITIONS 301 | 302 | APPENDIX: How to apply the Apache License to your work. 303 | 304 | To apply the Apache License to your work, attach the following 305 | boilerplate notice, with the fields enclosed by brackets "[]" 306 | replaced with your own identifying information. (Don't include 307 | the brackets!) The text should be enclosed in the appropriate 308 | comment syntax for the file format. We also recommend that a 309 | file or class name and description of purpose be included on the 310 | same "printed page" as the copyright notice for easier 311 | identification within third-party archives. 312 | 313 | Copyright [yyyy] [name of copyright owner] 314 | 315 | Licensed under the Apache License, Version 2.0 (the "License"); 316 | you may not use this file except in compliance with the License. 317 | You may obtain a copy of the License at 318 | 319 | http://www.apache.org/licenses/LICENSE-2.0 320 | 321 | Unless required by applicable law or agreed to in writing, software 322 | distributed under the License is distributed on an "AS IS" BASIS, 323 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 324 | See the License for the specific language governing permissions and 325 | limitations under the License. 
326 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 
63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) 72 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. 
APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | 42 | ------------------------------------------------------------------------------- 43 | SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY: 44 | 45 | The Cut Cross Entropy software includes a number of subcomponents with separate 46 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md. 47 | ------------------------------------------------------------------------------- 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Cut Your Losses in Large-Vocabulary Language Models 2 | 3 | This software project accompanies the research paper: 4 | **[Cut Your Losses in Large-Vocabulary Language Models](https://arxiv.org/abs/2411.09009)**, 5 | *Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, and Philipp Krähenbühl*. 6 | 7 | ![](assets/cce_figure.png) 8 | 9 | As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence. 10 | 11 | ## Getting started 12 | 13 | **Requirements** 14 | 15 | 1. Python 3.10+ 16 | 2. PyTorch 2.4+ 17 | 3. Triton 3.0+ 18 | 4. 
Ampere (or newer) GPU 19 | 20 | 21 | **Note:** For operating systems that are not supported by Triton (e.g., macOS), we include a highly optimized version of 22 | linear-cross-entropy using `torch.compile`. This implementation is used by default on macOS. 23 | 24 | ### Basic usage 25 | 26 | **Installation** 27 | ```bash 28 | pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git" 29 | ``` 30 | 31 | **Usage** 32 | 33 | ```python 34 | from cut_cross_entropy import linear_cross_entropy 35 | 36 | embeddings = model.compute_embedding(inputs) 37 | classifier = model.get_classifier_weights() 38 | 39 | loss = linear_cross_entropy(embeddings, classifier, labels) 40 | ``` 41 | 42 | In causal language modeling, it is common that the model embeddings and labels need to be shifted 43 | such that the model predicts the next token. 44 | 45 | ```python 46 | from cut_cross_entropy import linear_cross_entropy 47 | 48 | embeddings = model.compute_embedding(inputs) 49 | classifier = model.get_classifier_weights() 50 | 51 | shift_embeddings = embeddings[..., :-1, :].flatten(0, -2) 52 | shift_labels = labels[..., 1:] 53 | 54 | manual_shift_loss = linear_cross_entropy(shift_embeddings, classifier, shift_labels) 55 | ``` 56 | 57 | Instead, pass `shift=1` to perform this computation without allocating the `shift_embeddings` matrix. 58 | ```python 59 | from cut_cross_entropy import linear_cross_entropy 60 | 61 | embeddings = model.compute_embedding(inputs) 62 | classifier = model.get_classifier_weights() 63 | 64 | # This is the same as manual_shift_loss above 65 | auto_shift_loss = linear_cross_entropy(embeddings, classifier, labels, shift=1) 66 | ``` 67 | 68 | We also provide a highly optimized implementation of linear-cross-entropy loss using `torch.compile`. 69 | This is a good option 70 | for scenarios where speed is the primary goal and the model has a relatively small vocabulary compared to its 71 | hidden dimension (when |V| >> D, `cce` will both save memory _and_ be faster). 72 | This option also works on the CPU and older GPUs, making it useful for testing. 73 | 74 | ```python 75 | from cut_cross_entropy import linear_cross_entropy 76 | 77 | embeddings = model.compute_embedding(inputs) 78 | classifier = model.get_classifier_weights() 79 | 80 | loss = linear_cross_entropy(embeddings, classifier, labels, ..., impl="torch_compile") 81 | ``` 82 | 83 | 84 | There are several other implementations available depending on your needs; all are selected via the `impl` argument (see the example following the table). 85 | 86 | | impl | Description | 87 | |------|-------------| 88 | | cce | The CCE implementation as described in the paper. This may be the fastest and uses the least amount of memory. Generally recommended to start here. | 89 | | torch_compile | A highly optimized `torch.compile` implementation. This is typically the fastest but uses the most amount of memory. Good as a reference and for systems that don't support Triton. | 90 | | cce_kahan | Uses Kahan summation (or fp32) to improve numerical precision. This comes at the cost of more memory usage (albeit only a temporary buffer in the backward pass). This is useful for long sequence lengths or if the model is particularly sensitive to numerical imprecision. 91 | | cce_kahan_full_c | Same as cce_kahan but additionally removes gradient filtering on the classifier gradient. This is useful for pretraining but will be slower. 92 | | cce_kahan_full_c_full_e (cce_exact) | This additionally removes gradient filtering from the embedding gradient. This is useful as a reference point/sanity check. |
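All of these variants are selected through the same `impl` argument. A minimal sketch, reusing the `embeddings`, `classifier`, and `labels` placeholders from the snippets above:

```python
from cut_cross_entropy import linear_cross_entropy

# Default CCE kernel (same as omitting impl).
loss = linear_cross_entropy(embeddings, classifier, labels, impl="cce")

# Kahan-summation variant for long sequences or numerically sensitive models.
loss_kahan = linear_cross_entropy(embeddings, classifier, labels, impl="cce_kahan")

# Exact variant with all gradient filtering removed; useful as a reference/sanity check.
loss_exact = linear_cross_entropy(embeddings, classifier, labels, impl="cce_exact")
```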
93 | 94 | 95 | ### Vocabulary Parallelism 96 | 97 | We also support computing linear cross-entropy loss for classifier weights sharded 98 | along the vocabulary dimension. To use this, provide a `VocabParallelOptions` instance 99 | to `linear_cross_entropy`. This takes three parameters: the `start` and `stop` indices of this rank's 100 | shard, and the `torch.distributed.ProcessGroup` for this rank's vocab parallel group. 101 | 102 | 103 | 104 | ```python 105 | import torch 106 | 107 | from cut_cross_entropy import linear_cross_entropy, VocabParallelOptions 108 | 109 | # The vocab parallel group for this rank. 110 | # This group can be created/retrieved in many different ways, 111 | # for instance, 112 | # torch.distributed.new_group(...) 113 | # device_mesh.get_group(mesh_dim="model_parallel") 114 | # etc 115 | vp_group = ... 116 | 117 | 118 | embeddings = model.compute_embedding(inputs) 119 | vp_classifier = model.get_classifier_weights() 120 | 121 | vp_start, vp_stop = model.get_classifier_range() 122 | vp_opts = VocabParallelOptions(vp_start, vp_stop, group=vp_group) 123 | 124 | # alternatively, there is an option to create this 125 | # by linearly dividing the vocab across ranks 126 | vp_opts = VocabParallelOptions.from_vocab(model.vocab_size, group=vp_group) 127 | 128 | # All ranks in the vocab parallel group will return the same loss 129 | loss = linear_cross_entropy(embeddings, vp_classifier, labels, ..., 130 | vocab_parallel_options=vp_opts) 131 | 132 | loss.backward() 133 | 134 | # All ranks will compute the same embeddings.grad, but each rank will have only the classifier gradient 135 | # corresponding to its part of the full classifier matrix (as defined by vp_classifier). 136 | ``` 137 | 138 | 139 | 140 | ### Computing Related Quantities 141 | 142 | `linear_cross_entropy` can be used as an efficient way to compute the negative log likelihood 143 | of a specified token. This can be used to compute various quantities. 144 | 145 | 146 | ```python 147 | from cut_cross_entropy import linear_cross_entropy 148 | 149 | 150 | # linear_cross_entropy computes negative log likelihood for a target token 151 | nll = linear_cross_entropy(embeddings, classifier, target_token, reduction="none") 152 | 153 | # Perplexity 154 | ppl = torch.exp(nll.mean(-1)) 155 | 156 | # DPO (beta and reference omitted) 157 | dpo_loss = -F.logsigmoid(nll[dispreferred].sum(-1) - nll[preferred].sum(-1)) 158 | 159 | # PPO 160 | ppo_loss = -torch.minimum(torch.exp(-nll - old_logp) * adv, adv + eps * adv.abs()) 161 | ``` 162 | 163 | 164 | ### Generalized Usage 165 | 166 | While we have discussed using CCE in the context of large language models, the only constraint 167 | on using CCE is that the loss can be formulated using something that resembles the following: 168 | 169 | ```python 170 | logits = X @ A.T + b # (b is an optional bias) 171 | loss = F.cross_entropy(logits.float(), targets) 172 | ``` 173 | 174 | Given that format, CCE can then be used as 175 | ```python 176 | loss = linear_cross_entropy(X, A, targets, bias=b) 177 | ``` 178 | 179 | This formulation is very general and encompasses vision models, contrastive losses (e.g., CLIP), etc.
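As a concrete illustration outside of language modeling, the same call works for any linear classification head. The following is a hypothetical sketch; `features`, `head_weight`, and `class_ids` are placeholder names (not part of the library), and a CUDA device with bfloat16 support is assumed:

```python
import torch

from cut_cross_entropy import linear_cross_entropy

# Hypothetical image-classification head: 1024 examples, 768-dim features, 1000 classes.
features = torch.randn(1024, 768, device="cuda", dtype=torch.bfloat16)
head_weight = torch.randn(1000, 768, device="cuda", dtype=torch.bfloat16)  # (num_classes, feature_dim)
class_ids = torch.randint(0, 1000, (1024,), device="cuda")

# Equivalent to F.cross_entropy((features @ head_weight.T).float(), class_ids),
# but the (1024, 1000) logit matrix is never materialized in global memory.
loss = linear_cross_entropy(features, head_weight, class_ids)
```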
180 | 181 | 182 | ### Transformers Integration 183 | 184 | **Installation** 185 | 186 | Install cut-cross-entropy with the transformers dependencies: 187 | ```bash 188 | pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git" 189 | ``` 190 | 191 | **Usage** 192 | 193 | If you are using transformers, you can patch transformers to use CCE directly. Note that 194 | logits will no longer be returned (`None` will be returned instead). 195 | ```python 196 | from cut_cross_entropy.transformers import cce_patch 197 | 198 | cce_patch("llama") 199 | 200 | # or 201 | 202 | model = ... 203 | model = cce_patch(model) 204 | ``` 205 | 206 | We currently support the Llama, Phi3, Mistral, and Gemma2 families of models. 207 | 208 | `cce_patch` takes two options. The first is the linear-cross-entropy implementation to use, currently `"cce"` or `"torch_compile"`. 209 | 210 | The second 211 | is the loss reduction. We support `"mean"`, `"sum"`, and `"none"`, which mirror their PyTorch counterparts. 212 | `"mean"` is the default and what the transformers trainer API expects. 213 | However, 214 | `"none"` in particular enables efficient computation of quantities based on the loss. 215 | 216 | For example, the following efficiently computes the perplexity of a batch of sequences: 217 | ```python 218 | import transformers 219 | 220 | from cut_cross_entropy.transformers import cce_patch 221 | 222 | 223 | model = transformers.AutoModelForCausalLM.from_pretrained(...) 224 | 225 | model = cce_patch(model, reduction="none") 226 | 227 | labels = input_ids.clone() 228 | labels[~attention_mask] = -100 # -100 is the ignore index for PyTorch and CCE. 229 | 230 | outputs = model(input_ids, attention_mask, labels=labels) 231 | 232 | loss = outputs[0] # A (B, T - 1) tensor because reduction="none". T - 1 because the first input token has 233 | # no loss. 234 | 235 | ppl = torch.exp( 236 | # [:, 1:] because the first token has no loss 237 | loss.sum(1) / (labels[:, 1:] != -100).count_nonzero(dim=1) 238 | ).mean() # Average perplexity over the batch 239 | ``` 240 | 241 | 242 | 243 | ### Training and reproducing the benchmark results 244 | 245 | We provide a training script in `training/train.py`. 246 | 247 | **Installation** 248 | ```bash 249 | pip install "cut-cross-entropy[all] @ git+https://github.com/apple/ml-cross-entropy.git" 250 | ``` 251 | 252 | **Training** 253 | 254 | Use `scripts/train.sh` to train a full model. 255 | 256 | **Benchmarking** 257 | 258 | The benchmark script can be run via `python -m benchmark`. 259 | 260 | Expected output with an A100 SXM4, PyTorch 2.4.1, and CUDA 12.4: 261 | 262 | ``` 263 | method kind runtime_ms op_mem_mb test_data 264 | 0 cce loss-fw 46.4 1.1 gemma2 265 | 1 torch_compile loss-fw 49.9 4000.1 gemma2 266 | 2 baseline loss-fw 81.9 24000.0 gemma2 267 | 3 cce loss-bw 89.3 1163.0 gemma2 268 | 4 torch_compile loss-bw 92.3 12000.0 gemma2 269 | 5 baseline loss-bw 122.4 16000.0 gemma2 270 | 6 cce loss-fw-bw 134.8 1164.0 gemma2 271 | 7 torch_compile loss-fw-bw 144.0 16000.1 gemma2 272 | 8 baseline loss-fw-bw 208.8 28000.0 gemma2 273 | ``` 274 | 275 | ### Development 276 | 277 | If dependencies are installed locally, `cut-cross-entropy` will work without a pip install as long as `python` is executed in the root of the GitHub repo. 278 | 279 | To install directly from the GitHub repo, either use an (editable) install or manipulate `PYTHONPATH`, e.g.
280 | 281 | ```bash 282 | pip install -e ".[dev]" 283 | 284 | # or 285 | pip install ".[dev]" 286 | 287 | # or 288 | export PYTHONPATH=/path/to/ml-cross-entropy:${PYTHONPATH} 289 | ``` 290 | 291 | ## Citation 292 | 293 | ``` 294 | @inproceedings{wijmans2025cut, 295 | author = {Erik Wijmans and 296 | Brody Huval and 297 | Alexander Hertzberg and 298 | Vladlen Koltun and 299 | Philipp Kr\"ahenb\"uhl}, 300 | title = {Cut Your Losses in Large-Vocabulary Language Models}, 301 | booktitle = {International Conference on Learning Representations}, 302 | year = {2025}, 303 | } 304 | ``` 305 | 306 | 307 | ## License 308 | This sample code is released under the [LICENSE](LICENSE) terms. 309 | 310 | ## Acknowledgements 311 | 312 | Our codebase is built using multiple opensource contributions, please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details. 313 | 314 | Please check the paper for a complete list of references and datasets used in this work. 315 | -------------------------------------------------------------------------------- /assets/cce_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-cross-entropy/b616b222976b235647790a16d0388338b9e18941/assets/cce_figure.png -------------------------------------------------------------------------------- /benchmark/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import functools 3 | import gc 4 | import time 5 | from dataclasses import asdict 6 | from pathlib import Path 7 | from typing import Any 8 | 9 | import pandas as pd 10 | import torch 11 | import torch.nn.functional as F 12 | import tqdm 13 | from fire import Fire 14 | 15 | from cut_cross_entropy import linear_cross_entropy 16 | from cut_cross_entropy.constants import IGNORE_INDEX 17 | 18 | from . 
import data, memory 19 | 20 | 21 | def baseline( 22 | e: torch.Tensor, 23 | c: torch.Tensor, 24 | targets: torch.Tensor, 25 | softcap: float | None = None, 26 | ignore_index: int = IGNORE_INDEX, 27 | reduction: str = "mean", 28 | ) -> torch.Tensor: 29 | logits = e @ c.T 30 | 31 | if softcap is not None: 32 | logits = torch.tanh(logits / softcap) * softcap 33 | 34 | return F.cross_entropy(logits.float(), targets, ignore_index=ignore_index, reduction=reduction) 35 | 36 | 37 | def clear_grad_fn(E, C, *args, **kwargs): 38 | E.grad = C.grad = None 39 | 40 | 41 | def benchmark( 42 | methods: list[str] | str | None = None, 43 | test_data: list[str] | str | None = "gemma2", 44 | n_iteration: int = 50, 45 | n_rep: int = 1, 46 | dtype: str = "bfloat16", 47 | output: str | None = None, 48 | kinds: list[str] | str | None = None, 49 | softcap: float | None = None, 50 | ): 51 | torch.set_float32_matmul_precision("high") 52 | 53 | if methods is None: 54 | methods = ["cce", "cce_full_c", "torch_compile", "baseline", "cce_exact"] 55 | elif isinstance(methods, str): 56 | methods = methods.split(",") 57 | 58 | if kinds is None: 59 | kinds = ["loss-fw", "loss-bw", "loss-fw-bw"] 60 | elif isinstance(kinds, str): 61 | kinds = kinds.split(",") 62 | 63 | if test_data is None: 64 | test_data = ["gemma2", "llama3", "mistral-nemo", "phi3.5"] 65 | elif isinstance(test_data, str): 66 | test_data = test_data.split(",") 67 | 68 | dtype = getattr(torch, dtype) 69 | 70 | all_stats = [] 71 | 72 | for this_test_data in tqdm.tqdm(test_data, desc="Data source", disable=len(test_data) == 1): 73 | gen = data.generator(this_test_data) 74 | for rep in tqdm.trange(n_rep + 1, desc="Repetition"): 75 | D = gen(dtype=dtype) 76 | for kind in tqdm.tqdm(kinds, desc="Benchmark kind", disable=len(kinds) == 1): 77 | E, C, T = D.embedding, D.classifier, D.targets 78 | 79 | this_softcap = softcap if softcap is not None else D.softcap 80 | 81 | kwargs: dict[str, Any] = {"softcap": this_softcap} 82 | if kind == "loss-fw": 83 | E.requires_grad_(True) 84 | C.requires_grad_(True) 85 | args = (E, C, T) 86 | elif kind in {"loss-bw", "loss-fw-bw"}: 87 | E.requires_grad_(True) 88 | C.requires_grad_(True) 89 | 90 | args = (E, C, T) 91 | kwargs["backward"] = True 92 | kwargs["forward"] = kind == "loss-fw-bw" 93 | kwargs["pre_fn"] = clear_grad_fn 94 | else: 95 | raise ValueError(f"Unknown kind {kind=}") 96 | 97 | for m in tqdm.tqdm(methods, desc="Method", leave=False): 98 | if m in "liger" and kind.startswith("lse"): 99 | continue 100 | 101 | # warmup (it==0) 102 | stats = memory.Stats.measure( 103 | baseline 104 | if m == "baseline" 105 | else functools.partial(linear_cross_entropy, impl=m), 106 | *args, 107 | n_iteration=n_iteration if rep > 0 else 1, 108 | **kwargs, 109 | ) 110 | 111 | if rep > 0 or n_rep == 0: 112 | this_stats = { 113 | "method": m, 114 | "kind": kind, 115 | } | asdict(stats) 116 | 117 | this_stats["test_data"] = this_test_data 118 | 119 | all_stats.append(this_stats) 120 | 121 | torch.cuda.synchronize() 122 | time.sleep(1) 123 | gc.collect() 124 | torch.cuda.empty_cache() 125 | time.sleep(1) 126 | 127 | all_stats = pd.DataFrame(all_stats) 128 | pd.options.display.float_format = "{:.1f}".format 129 | print(all_stats) 130 | if output is not None: 131 | output_path = Path(output) 132 | output_path.parent.mkdir(parents=True, exist_ok=True) 133 | all_stats.to_csv(output_path) 134 | 135 | 136 | if __name__ == "__main__": 137 | Fire(benchmark) 138 | -------------------------------------------------------------------------------- 
/benchmark/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import functools 3 | from collections.abc import Callable 4 | 5 | from .data import Data 6 | from .models import generate_test_data_otf, load_model 7 | from .randn import generate as randn_generate 8 | 9 | generators: dict[str, Callable[..., Data]] = { 10 | "llama3": functools.partial( 11 | generate_test_data_otf, 12 | "meta-llama/Meta-Llama-3-8B-Instruct", 13 | ), 14 | "llama3.2-1": functools.partial( 15 | generate_test_data_otf, 16 | "meta-llama/Llama-3.2-1B-Instruct", 17 | ), 18 | "llama3.2-3": functools.partial( 19 | generate_test_data_otf, 20 | "meta-llama/Llama-3.2-3B-Instruct", 21 | ), 22 | "llama3-70": functools.partial( 23 | generate_test_data_otf, 24 | "meta-llama/Meta-Llama-3-70B-Instruct", 25 | ), 26 | "gemma2": functools.partial(generate_test_data_otf, "google/gemma-2-2b-it"), 27 | "gemma2-9": functools.partial(generate_test_data_otf, "google/gemma-2-9b-it"), 28 | "gemma2-27": functools.partial(generate_test_data_otf, "google/gemma-2-27b-it"), 29 | "phi3.5": functools.partial(generate_test_data_otf, "microsoft/Phi-3.5-mini-instruct"), 30 | "mistral-nemo": functools.partial( 31 | generate_test_data_otf, "mistralai/Mistral-Nemo-Instruct-2407" 32 | ), 33 | } 34 | 35 | generators = generators | { 36 | f"{k}-invalids": functools.partial(v, keep_invalids=True) for k, v in generators.items() 37 | } 38 | 39 | generators["randn"] = randn_generate 40 | 41 | all_fig1_models = [ 42 | "google/gemma-2-2b", 43 | "google/gemma-2b", 44 | "meta-llama/Llama-2-7b-chat-hf", 45 | "microsoft/Phi-3.5-mini-instruct", 46 | "meta-llama/Meta-Llama-3-8B", 47 | "google/gemma-2-9b", 48 | "meta-llama/Meta-Llama-3-70B", 49 | "meta-llama/Llama-2-13b-chat-hf", 50 | "openai-community/gpt2", 51 | "mistralai/Mistral-7B-v0.1", 52 | "microsoft/phi-1_5", 53 | "EleutherAI/gpt-neo-1.3B", 54 | "EleutherAI/gpt-neo-2.7B", 55 | "mistralai/Mixtral-8x7B-Instruct-v0.1", 56 | "google/gemma-2-27b-it", 57 | "microsoft/Phi-3-medium-128k-instruct", 58 | ] 59 | 60 | 61 | def generator(name: str) -> Callable[..., Data]: 62 | if name not in generators: 63 | raise ValueError(f"Data generator {name!r} not found.") 64 | 65 | load_model.cache_clear() 66 | return generators[name] 67 | -------------------------------------------------------------------------------- /benchmark/data/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Data: 9 | embedding: torch.Tensor 10 | classifier: torch.Tensor 11 | targets: torch.Tensor 12 | softcap: float | None = None 13 | 14 | @property 15 | def required_storage(self) -> float: 16 | return ( 17 | self.embedding.element_size() * self.embedding.numel() 18 | + self.classifier.element_size() * self.classifier.numel() 19 | ) 20 | -------------------------------------------------------------------------------- /benchmark/data/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | import functools 3 | import random 4 | 5 | import torch 6 | import transformers 7 | from torch.utils.data import DataLoader 8 | 9 | from training.train import ( 10 | DataArguments, 11 | make_supervised_data_module, 12 | ) 13 | 14 | from .data import Data 15 | 16 | 17 | @functools.cache 18 | def generator_for_model(model_name: str) -> random.Random: 19 | return random.Random(0) 20 | 21 | 22 | @functools.lru_cache(1) 23 | def load_model(model_name: str, inference_device: torch.device): 24 | config = transformers.AutoConfig.from_pretrained(model_name) 25 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) 26 | attn_impl = "flash_attention_2" if config.model_type != "gemma2" else "eager" 27 | if config.model_type == "mistral": 28 | tokenizer.padding_side = "left" 29 | tokenizer.pad_token = "" 30 | elif config.model_type == "llama": 31 | tokenizer.pad_token = "<|reserved_special_token_0|>" 32 | 33 | causal_lm = transformers.AutoModelForCausalLM.from_pretrained( 34 | model_name, 35 | attn_implementation=attn_impl, 36 | torch_dtype=torch.bfloat16, 37 | device_map=inference_device, 38 | ).to(device=inference_device) 39 | causal_lm.eval() 40 | 41 | return causal_lm, tokenizer, config 42 | 43 | 44 | def generate_test_data_otf( 45 | model_name: str, 46 | dataset_name: str = "yahma/alpaca-cleaned", 47 | device: str = "cuda", 48 | dtype: torch.dtype = torch.bfloat16, 49 | keep_invalids: bool = False, 50 | M: int = 8 * 1024, 51 | ): 52 | default_device = torch.cuda.current_device() 53 | inference_device = torch.device("cuda", (default_device + 1) % torch.cuda.device_count()) 54 | torch.cuda.set_device(inference_device) 55 | 56 | causal_lm, tokenizer, config = load_model(model_name, inference_device) 57 | 58 | data_module = make_supervised_data_module( 59 | tokenizer, 60 | DataArguments(dataset_name), 61 | seed=generator_for_model(model_name).randint(0, 2**20), 62 | uses_system_prompt=config.model_type not in ("gemma2",), 63 | ) 64 | 65 | generator = torch.Generator().manual_seed(generator_for_model(model_name).randint(0, 2**20)) 66 | 67 | dl = DataLoader( 68 | data_module["train_dataset"], 69 | batch_size=8, 70 | shuffle=True, 71 | num_workers=1, 72 | pin_memory=True, 73 | generator=generator, 74 | collate_fn=data_module["data_collator"], 75 | ) 76 | 77 | decoder = causal_lm.get_decoder() 78 | inputs_l = [] 79 | labels_l = [] 80 | 81 | with torch.inference_mode(): 82 | for batch in dl: 83 | batch = {k: v.to(device=inference_device) for k, v in batch.items()} 84 | 85 | outputs = decoder( 86 | input_ids=batch["input_ids"], 87 | attention_mask=batch["attention_mask"], 88 | position_ids=batch["position_ids"], 89 | ) 90 | 91 | hidden = outputs[0] 92 | 93 | labels = batch["labels"] 94 | shift_hidden = hidden[..., :-1, :].contiguous() 95 | shift_labels = labels[..., 1:].contiguous() 96 | shift_hidden = shift_hidden.view(-1, config.hidden_size) 97 | shift_labels = shift_labels.view(-1) 98 | 99 | if not keep_invalids: 100 | valids = (shift_labels != -100).nonzero(as_tuple=True) 101 | shift_hidden = shift_hidden[valids] 102 | shift_labels = shift_labels[valids] 103 | 104 | inputs_l.append(shift_hidden) 105 | labels_l.append(shift_labels) 106 | 107 | if sum(v.numel() for v in labels_l) >= M: 108 | break 109 | 110 | inputs = torch.cat(inputs_l)[0:M].clone().contiguous() 111 | labels = torch.cat(labels_l)[0:M].clone().contiguous() 112 | w = causal_lm.get_output_embeddings().weight.detach().clone().contiguous() 113 | torch.cuda.set_device(default_device) 114 | 115 | return Data( 116 | 
inputs.to(device=default_device), 117 | w.to(device=default_device), 118 | labels.to(device=default_device), 119 | softcap=getattr(config, "final_logit_softcapping", None), 120 | ) 121 | -------------------------------------------------------------------------------- /benchmark/data/randn.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import torch 3 | 4 | from .data import Data 5 | 6 | 7 | def generate( 8 | device: str = "cuda", 9 | dtype: torch.dtype = torch.bfloat16, 10 | M: int = 256, # 16 * 1024, 11 | N: int = 512, 12 | D: int = 128, 13 | fraction_invalid_labels: float = 0.2, 14 | ) -> Data: 15 | """Random data generation 16 | 17 | Args: 18 | ---- 19 | device (str, optional): Cuda device. Defaults to "cuda". 20 | dtype (torch.dtype, optional): Tensor type (torch.float32 or torch.bfloat16). Defaults to torch.bfloat16. 21 | M (int, optional): Sequence length. Defaults to 16000. 22 | N (int, optional): Vocabulary size. Defaults to 128000. 23 | D (int, optional): Embedding dimension. Defaults to 4096. 24 | 25 | Returns: 26 | ------- 27 | Data: A data sample. 28 | 29 | """ 30 | W = torch.randn(N, D, device=device, dtype=dtype, requires_grad=False) / D**0.25 31 | x = torch.randn(M, D, device=device, dtype=dtype, requires_grad=False) / D**0.25 32 | # Get some values that are non-zero in expectation 33 | W[:M] = x[: min(N, M)] 34 | targets = torch.randint(0, N, size=(M,), device=device) 35 | # targets[0 : int(M * fraction_invalid_labels)] = -100 36 | targets = targets[torch.randperm(M, device=targets.device)] 37 | 38 | return Data(x, W, targets) 39 | -------------------------------------------------------------------------------- /benchmark/memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | import gc 3 | from collections.abc import Iterable 4 | from dataclasses import dataclass 5 | from typing import Callable, TypeVar 6 | 7 | import torch 8 | from typing_extensions import Self 9 | 10 | T = TypeVar("T") 11 | 12 | 13 | def mem_delta(s1, s0): 14 | return s1["allocated_bytes.all.peak"] - s0["allocated_bytes.all.peak"] 15 | 16 | 17 | def calc_tensor_memory_usage(t: torch.Tensor | Iterable[torch.Tensor]) -> float: 18 | mem = 0.0 19 | if isinstance(t, torch.Tensor): 20 | mem += t.numel() * torch.finfo(t.dtype).bits / 8 21 | elif isinstance(t, Iterable): 22 | mem += sum(calc_tensor_memory_usage(v) for v in t) 23 | else: 24 | raise RuntimeError(f"Unknown return type {type(t)}") 25 | 26 | return mem 27 | 28 | 29 | def detach_rval(t: torch.Tensor, *_args, **_kwargs) -> torch.Tensor: 30 | return t.detach() 31 | 32 | 33 | @dataclass(frozen=True) 34 | class Stats: 35 | runtime_ms: float # in milliseconds 36 | op_mem_mb: float # in MB 37 | 38 | @classmethod 39 | def measure( 40 | cls, 41 | f: Callable[..., torch.Tensor], 42 | *args, 43 | n_iteration: int = 1, 44 | forward: bool = True, 45 | backward: bool = False, 46 | pre_fn: Callable[..., None] | None = None, 47 | **kwds, 48 | ) -> Self: 49 | if pre_fn is not None: 50 | pre_fn(*args, **kwds) 51 | 52 | torch.cuda.synchronize() 53 | gc.collect() 54 | torch.cuda.empty_cache() 55 | 56 | mem_usages = 0 57 | torch.cuda.reset_peak_memory_stats() 58 | s0 = torch.cuda.memory_stats() 59 | t = f(*args, **kwds) 60 | 61 | torch.cuda.synchronize() 62 | if forward: 63 | s1 = torch.cuda.memory_stats() 64 | 65 | if backward and not forward: 66 | torch.cuda.reset_peak_memory_stats() 67 | s0 = torch.cuda.memory_stats() 68 | 69 | if backward: 70 | if t.numel() > 1: 71 | t = t.mean() 72 | t.backward() 73 | 74 | torch.cuda.synchronize() 75 | if backward: 76 | s1 = torch.cuda.memory_stats() 77 | 78 | mem_usages = mem_delta(s1, s0) 79 | 80 | # Run the function to benchmark 81 | # rval = None 82 | all_cuda_events: list[tuple[torch.cuda.Event, torch.cuda.Event]] = [] 83 | for _ in range(n_iteration): 84 | if pre_fn is not None: 85 | pre_fn(*args, **kwds) 86 | 87 | start = torch.cuda.Event(enable_timing=True) 88 | 89 | if forward: 90 | start.record(torch.cuda.current_stream()) 91 | 92 | t = f(*args, **kwds) 93 | 94 | if not forward: 95 | start.record(torch.cuda.current_stream()) 96 | 97 | if backward: 98 | if t.numel() > 1: 99 | t = t.mean() 100 | t.backward() 101 | 102 | end = torch.cuda.Event(enable_timing=True) 103 | end.record(torch.cuda.current_stream()) 104 | 105 | all_cuda_events.append((start, end)) 106 | 107 | torch.cuda.synchronize() 108 | total_time = 0.0 109 | for start, end in all_cuda_events: 110 | start.synchronize() 111 | end.synchronize() 112 | total_time += start.elapsed_time(end) 113 | 114 | s1 = torch.cuda.memory_stats() 115 | 116 | return cls( 117 | total_time / n_iteration, 118 | mem_usages / 2**20, 119 | ) 120 | -------------------------------------------------------------------------------- /cut_cross_entropy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2 | from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl 3 | from cut_cross_entropy.linear_cross_entropy import ( 4 | LinearCrossEntropy, 5 | linear_cross_entropy, 6 | ) 7 | from cut_cross_entropy.vocab_parallel import VocabParallelOptions 8 | 9 | __all__ = [ 10 | "LinearCrossEntropy", 11 | "LinearCrossEntropyImpl", 12 | "linear_cross_entropy", 13 | "VocabParallelOptions", 14 | ] 15 | 16 | 17 | __version__ = "25.5.1" 18 | -------------------------------------------------------------------------------- /cut_cross_entropy/cce.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from dataclasses import dataclass 3 | from typing import cast 4 | 5 | import torch 6 | 7 | from cut_cross_entropy.cce_backward import cce_backward_kernel 8 | from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel 9 | from cut_cross_entropy.constants import IGNORE_INDEX 10 | from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start 11 | from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel 12 | from cut_cross_entropy.utils import ( 13 | _build_flat_valids, 14 | _handle_eps, 15 | handle_reduction_none, 16 | ) 17 | from cut_cross_entropy.vocab_parallel.utils import ( 18 | VocabParallelOptions, 19 | vp_reduce_correct_logit, 20 | vp_reduce_lse, 21 | ) 22 | 23 | 24 | @dataclass 25 | class CCEParams: 26 | targets: torch.Tensor 27 | valids: torch.Tensor | None 28 | softcap: float | None 29 | reduction: str 30 | filter_eps: float | None 31 | shift: int 32 | batch_shape: torch.Size 33 | accum_e_fp32: bool 34 | accum_c_fp32: bool 35 | filter_e_grad: bool 36 | filter_c_grad: bool 37 | vocab_parallel_options: VocabParallelOptions | None 38 | 39 | 40 | @torch.compile(fullgraph=True) 41 | def sort_logit_avg(logit_avg: torch.Tensor) -> torch.Tensor: 42 | return torch.argsort(logit_avg).to(torch.int32) 43 | 44 | 45 | class LinearCrossEntropyFunction(torch.autograd.Function): 46 | @staticmethod 47 | def forward( 48 | ctx, 49 | e: torch.Tensor, 50 | c: torch.Tensor, 51 | bias: torch.Tensor | None, 52 | params: CCEParams, 53 | ) -> torch.Tensor: 54 | needs_grad = e.requires_grad or c.requires_grad 55 | return_logit_avg = needs_grad and params.filter_eps is not None 56 | 57 | ret = cce_lse_forward_kernel( 58 | e=e, 59 | c=c, 60 | bias=bias, 61 | valids=params.valids, 62 | softcap=params.softcap, 63 | return_logit_avg=return_logit_avg, 64 | ) 65 | if return_logit_avg: 66 | assert isinstance(ret, tuple) 67 | lse, logit_avg = ret 68 | else: 69 | assert isinstance(ret, torch.Tensor) 70 | lse = ret 71 | logit_avg = None 72 | 73 | if (vp_opts := params.vocab_parallel_options) is not None: 74 | lse = vp_reduce_lse(lse, pg=vp_opts.group) 75 | 76 | if params.valids is not None: 77 | targets = params.targets[params.valids + params.shift] 78 | else: 79 | targets = params.targets 80 | 81 | vp_valids = ( 82 | ((targets >= vp_opts.start) & (targets < vp_opts.stop)).nonzero().to(torch.int32) 83 | ) 84 | assert vp_valids.size(1) == 1 85 | vp_valids = vp_valids.squeeze(-1) 86 | 87 | if params.valids is not None: 88 | neg_dot_valids = params.valids[vp_valids] 89 | else: 90 | neg_dot_valids = vp_valids 91 | 92 | neg_dot_targets = params.targets - vp_opts.start 93 | else: 94 | neg_dot_valids = params.valids 95 | neg_dot_targets = params.targets 96 | vp_valids = None 97 | 98 | neg_dot = indexed_neg_dot_forward_kernel( 99 | e=e, 100 | c=c, 101 | inds=neg_dot_targets, 102 | bias=bias, 103 | 
shift=params.shift, 104 | valids=neg_dot_valids, 105 | softcap=params.softcap, 106 | out_dtype=lse.dtype, 107 | ) 108 | 109 | if params.vocab_parallel_options is not None: 110 | global_neg_dot = neg_dot.new_zeros(lse.size()) 111 | assert vp_valids is not None 112 | global_neg_dot[vp_valids] = neg_dot 113 | 114 | neg_dot = vp_reduce_correct_logit( 115 | global_neg_dot, pg=params.vocab_parallel_options.group, dtype=lse.dtype 116 | ) 117 | 118 | nll = neg_dot.add_(lse) 119 | 120 | reduction = params.reduction 121 | if reduction == "mean": 122 | loss = nll.mean() 123 | elif reduction == "sum": 124 | loss = nll.sum() 125 | elif reduction == "none": 126 | loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll) 127 | else: 128 | raise ValueError(f"Unknown reduction {reduction}") 129 | 130 | ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg) 131 | ctx.params = params 132 | 133 | return loss 134 | 135 | @staticmethod 136 | def backward( 137 | ctx, grad_out: torch.Tensor 138 | ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, None]: 139 | e, c, bias, lse, targets, valids, logit_avg = ctx.saved_tensors 140 | 141 | if logit_avg is not None: 142 | vocab_ordering = sort_logit_avg(logit_avg) 143 | else: 144 | vocab_ordering = None 145 | 146 | params = cast(CCEParams, ctx.params) 147 | reduction = params.reduction 148 | if reduction == "mean": 149 | grad_scale = 1 / lse.numel() 150 | elif reduction == "sum": 151 | grad_scale = 1.0 152 | elif reduction == "none": 153 | grad_scale = 1.0 154 | grad_out = grad_out.view(-1) 155 | else: 156 | raise ValueError(f"Unknown reduction {reduction}") 157 | 158 | reduce_e_grad = False 159 | pg = None 160 | if (vp_opts := params.vocab_parallel_options) is not None: 161 | is_my_target = (targets >= vp_opts.start) & (targets < vp_opts.stop) 162 | targets = torch.where( 163 | is_my_target, 164 | targets - vp_opts.start, 165 | ## NB 166 | # The backward kernel already uses 167 | # c.size(0) + 1 as the padding value to ensure that 168 | # (targets.size(0) % block_size) == 0, so for targets 169 | # that aren't in this VP rank's range, we can just consider 170 | # them as padded and everything will work as expected.
171 | targets.new_full((), c.size(0) + 1), 172 | ) 173 | 174 | reduce_e_grad = vp_opts.reduce_e_grad 175 | pg = vp_opts.group 176 | 177 | de, dc, dbias = cce_backward_kernel( 178 | do=grad_out, 179 | e=e, 180 | c=c, 181 | bias=bias, 182 | lse=lse, 183 | valids=valids, 184 | softcap=params.softcap, 185 | filter_eps=params.filter_eps, 186 | targets=targets, 187 | shift=params.shift, 188 | vocab_ordering=vocab_ordering, 189 | grad_scale=grad_scale, 190 | accum_e_fp32=params.accum_e_fp32, 191 | accum_c_fp32=params.accum_c_fp32, 192 | filter_e_grad=params.filter_e_grad, 193 | filter_c_grad=params.filter_c_grad, 194 | reduce_e_grad=reduce_e_grad, 195 | pg=pg, 196 | ) 197 | 198 | return de, dc, dbias, None 199 | 200 | 201 | def linear_cross_entropy_apply( 202 | e: torch.Tensor, 203 | c: torch.Tensor, 204 | bias: torch.Tensor | None, 205 | params: CCEParams, 206 | ) -> torch.Tensor: 207 | loss = LinearCrossEntropyFunction.apply(e, c, bias, params) 208 | assert isinstance(loss, torch.Tensor) 209 | 210 | if params.shift != 0 and params.reduction == "none": 211 | loss = loss[..., params.shift :] 212 | 213 | return loss 214 | 215 | 216 | @add_doc_start(LINEAR_CROSS_ENTROPY_DOC) 217 | @add_doc_start(*(doc_str + "\n" for doc_str in CCE_OPTS_DOC)) 218 | def cce_linear_cross_entropy( 219 | e: torch.Tensor, 220 | c: torch.Tensor, 221 | targets: torch.Tensor, 222 | bias: torch.Tensor | None = None, 223 | ignore_index: int = IGNORE_INDEX, 224 | softcap: float | None = None, 225 | reduction: str = "mean", 226 | shift: bool | int = 0, 227 | filter_eps: float | str | None = "auto", 228 | accum_e_fp32: bool = False, 229 | accum_c_fp32: bool = False, 230 | filter_e_grad: bool = True, 231 | filter_c_grad: bool = True, 232 | vocab_parallel_options: VocabParallelOptions | None = None, 233 | ) -> torch.Tensor: 234 | assert e.size()[0:-1] == targets.size() 235 | assert e.size(-1) == c.size(1) 236 | if not torch.cuda.is_bf16_supported(): 237 | raise RuntimeError( 238 | "Cut Cross Entropy requires an ampere GPU or newer. " 239 | "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." 240 | ) 241 | 242 | batch_shape = targets.size() 243 | 244 | e = e.contiguous() 245 | targets = targets.contiguous() 246 | 247 | shift = int(shift) 248 | valids = _build_flat_valids(targets, ignore_index, shift) 249 | 250 | e = e.flatten(0, -2) 251 | targets = targets.flatten() 252 | 253 | if (targets.data_ptr() % 16) != 0: 254 | targets = torch.nn.functional.pad(targets, (0, 1))[:-1] 255 | 256 | assert (targets.data_ptr() % 16) == 0 257 | cce_params = CCEParams( 258 | targets, 259 | valids, 260 | softcap, 261 | reduction, 262 | _handle_eps(filter_eps, e.dtype), 263 | shift, 264 | batch_shape, 265 | accum_e_fp32, 266 | accum_c_fp32, 267 | filter_e_grad=filter_e_grad and filter_eps is not None, 268 | filter_c_grad=filter_c_grad and filter_eps is not None, 269 | vocab_parallel_options=vocab_parallel_options, 270 | ) 271 | 272 | return linear_cross_entropy_apply(e, c, bias, cce_params) 273 | -------------------------------------------------------------------------------- /cut_cross_entropy/cce_backward.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
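# Reference semantics (an illustrative sketch, not part of the original file): ignoring the
# gradient filter, vocab parallelism, softcapping, and Kahan accumulation, the kernel below
# computes the same gradients as this naive formulation, without materializing the (B, V)
# logit matrix:
#
#   # logits   = e @ c.T (+ bias)
#   # d_logits = softmax(logits) - one_hot(targets)        # (B, V), per valid row
#   # d_logits = d_logits * (grad_scale * d_out)           # d_out broadcast per row
#   # dE    = d_logits @ c
#   # dC    = d_logits.T @ e
#   # dBias = d_logits.sum(dim=0)
#
# The filter_eps logic skips (BLOCK_B, BLOCK_V) tiles whose |d_logits| entries are all below
# the threshold, which is where most of the backward-pass savings come from.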
2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | from cut_cross_entropy.tl_autotune import cce_backward_autotune 7 | from cut_cross_entropy.tl_utils import ( 8 | b_bin_fn, 9 | is_triton_greater_or_equal_3_2_0, 10 | tl_and_reduce_fn, 11 | tl_lock_add, 12 | tl_lock_kahan_sum, 13 | tl_softcapping, 14 | tl_softcapping_grad, 15 | ) 16 | from cut_cross_entropy.vocab_parallel.utils import vp_reduce_e_grad 17 | 18 | 19 | @triton.jit 20 | def _mm_backward( 21 | do, 22 | da_ptrs, 23 | dac_ptrs, 24 | partial_mask_a, 25 | da_lock_ptr, 26 | n_locks, 27 | b_ptrs, 28 | partial_mask_b, 29 | stride_ad, 30 | stride_bd, 31 | D, 32 | BLOCK_D: tl.constexpr, 33 | EVEN_D: tl.constexpr, 34 | USE_KAHAN: tl.constexpr, 35 | ): 36 | d_inds = tl.arange(0, BLOCK_D)[None, :].to(tl.int64) 37 | 38 | b_ptrs = b_ptrs + d_inds * stride_bd 39 | da_ptrs = da_ptrs + d_inds * stride_ad 40 | if USE_KAHAN: 41 | dac_ptrs = dac_ptrs + d_inds * stride_ad 42 | 43 | for d in range(0, tl.cdiv(D, BLOCK_D)): 44 | if EVEN_D: 45 | mask = partial_mask_b 46 | else: 47 | mask = partial_mask_b & (d_inds < (D - d * BLOCK_D)) 48 | 49 | b = tl.load(b_ptrs, mask=mask, other=0.0) 50 | 51 | da_i = tl.dot(do, b).to(da_ptrs.dtype.element_ty) 52 | 53 | if EVEN_D: 54 | mask = partial_mask_a 55 | else: 56 | mask = partial_mask_a & (d_inds < (D - d * BLOCK_D)) 57 | 58 | lock_offset = d // tl.cdiv(D, BLOCK_D * n_locks) 59 | this_da_lock_ptr = da_lock_ptr + lock_offset 60 | 61 | if USE_KAHAN: 62 | tl_lock_kahan_sum(da_ptrs, dac_ptrs, da_i, mask, this_da_lock_ptr) 63 | else: 64 | tl_lock_add(da_ptrs, da_i, mask, this_da_lock_ptr) 65 | 66 | b_ptrs += BLOCK_D * stride_bd 67 | da_ptrs += BLOCK_D * stride_ad 68 | if USE_KAHAN: 69 | dac_ptrs += BLOCK_D * stride_ad 70 | 71 | 72 | @triton.jit 73 | def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor: 74 | return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn) 75 | 76 | 77 | def _cce_backward_kernel( 78 | E, 79 | C, 80 | Bias, 81 | LSE, 82 | dOut, 83 | grad_scale, 84 | Valids, 85 | VocabOrdering, 86 | softcap, 87 | Targets, 88 | dE, 89 | dEC, 90 | dELocks, 91 | dC, 92 | dCC, 93 | dCLocks, 94 | dBias, 95 | B, 96 | D, 97 | V, 98 | BMax, 99 | n_de_locks_0, 100 | n_de_locks_1, 101 | n_dc_locks_0, 102 | n_dc_locks_1, 103 | stride_eb, 104 | stride_ed, 105 | stride_cv, 106 | stride_cd, 107 | stride_biasv, 108 | stride_vb, 109 | filter_eps, 110 | shift, 111 | B_BIN, 112 | BLOCK_B: tl.constexpr, 113 | BLOCK_V: tl.constexpr, 114 | BLOCK_D: tl.constexpr, 115 | MM_BACK_BLOCK_D: tl.constexpr, 116 | GROUP_B: tl.constexpr, 117 | EVEN_D: tl.constexpr, 118 | MM_BACK_EVEN_D: tl.constexpr, 119 | ITEM_DO: tl.constexpr, 120 | HAS_BIAS: tl.constexpr, 121 | HAS_VALIDS: tl.constexpr, 122 | HAS_VOCAB_ORDERING: tl.constexpr, 123 | FILTER_E_GRAD: tl.constexpr, 124 | FILTER_C_GRAD: tl.constexpr, 125 | HAS_TARGETS: tl.constexpr, 126 | HAS_SOFTCAP: tl.constexpr, 127 | HAS_SHIFT: tl.constexpr, 128 | KAHAN_E: tl.constexpr, 129 | KAHAN_C: tl.constexpr, 130 | COMPUTE_DC: tl.constexpr, 131 | COMPUTE_DE: tl.constexpr, 132 | COMPUTE_DBIAS: tl.constexpr, 133 | ): 134 | pid = tl.program_id(axis=0) 135 | num_b_chunks = tl.cdiv(B, BLOCK_B) 136 | num_v_chunks = tl.cdiv(V, BLOCK_V) 137 | num_v_in_group = GROUP_B * num_v_chunks 138 | group_id = pid // num_v_in_group 139 | first_pid_b = group_id * GROUP_B 140 | group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) 141 | pid_b = (first_pid_b + ((pid % num_v_in_group) % group_size_b)).to(tl.int64) 142 | pid_v = ((pid % num_v_in_group) // 
group_size_b).to(tl.int64) 143 | 144 | offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64) 145 | if HAS_VALIDS: 146 | offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax).to(tl.int64) 147 | 148 | offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)).to(tl.int64) 149 | if HAS_VOCAB_ORDERING: 150 | offs_v = tl.load(VocabOrdering + offs_v, mask=offs_v < V, other=V).to(tl.int64) 151 | 152 | offs_d = tl.arange(0, BLOCK_D).to(tl.int64) 153 | e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) 154 | c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) 155 | 156 | accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) 157 | for d in range(0, tl.cdiv(D, BLOCK_D)): 158 | e_mask = offs_b[:, None] < BMax 159 | if not EVEN_D: 160 | e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D)) 161 | 162 | e = tl.load(e_ptrs, mask=e_mask, other=0.0) 163 | 164 | c_mask = offs_v[None, :] < V 165 | if not EVEN_D: 166 | c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D)) 167 | 168 | c = tl.load(c_ptrs, mask=c_mask, other=0.0) 169 | 170 | accum = tl.dot(e, c, accum) 171 | 172 | e_ptrs += BLOCK_D * stride_ed 173 | c_ptrs += BLOCK_D * stride_cd 174 | 175 | tl.debug_barrier() 176 | 177 | if HAS_BIAS: 178 | bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0) 179 | bias = bias.to(dtype=accum.dtype) 180 | accum += bias[None, :] 181 | 182 | if HAS_SOFTCAP: 183 | accum = tl_softcapping(accum, softcap) 184 | 185 | if HAS_VALIDS: 186 | direct_offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64) 187 | lse = tl.load(LSE + direct_offs_b, mask=direct_offs_b < B, other=float("inf")) 188 | else: 189 | lse = tl.load(LSE + offs_b, mask=offs_b < B, other=float("inf")) 190 | 191 | d_accum = tl.exp(accum - lse[:, None]) 192 | d_accum = tl.where(offs_v[None, :] < V, d_accum, 0.0) 193 | 194 | if HAS_TARGETS: 195 | if HAS_SHIFT: 196 | target_offs_b = offs_b + shift 197 | else: 198 | target_offs_b = offs_b 199 | 200 | targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1) 201 | is_target = targets[:, None] == offs_v[None, :] 202 | d_accum += tl.where(is_target, -1.0, 0.0) 203 | else: 204 | is_target = None 205 | 206 | should_skip = False 207 | if (FILTER_E_GRAD and COMPUTE_DE) and (FILTER_C_GRAD and COMPUTE_DC): 208 | if _block_is_filtered(tl.abs(d_accum), filter_eps): 209 | return 210 | elif (FILTER_E_GRAD and COMPUTE_DE) or (FILTER_C_GRAD and COMPUTE_DC): 211 | should_skip = _block_is_filtered(tl.abs(d_accum), filter_eps) 212 | 213 | if HAS_SOFTCAP: 214 | d_accum = tl_softcapping_grad(d_accum, accum, softcap) 215 | 216 | if ITEM_DO: 217 | d_out = tl.load(dOut) 218 | else: 219 | if HAS_SHIFT: 220 | d_out_offs_b = offs_b + shift 221 | else: 222 | d_out_offs_b = offs_b 223 | 224 | d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0)[:, None] 225 | 226 | d_out = grad_scale * d_out 227 | 228 | d_accum = d_accum * d_out 229 | 230 | if COMPUTE_DBIAS: 231 | tl.atomic_add(dBias + offs_v * stride_biasv, tl.sum(d_accum, 0), mask=offs_v < V) 232 | 233 | d_accum = d_accum.to(e_ptrs.dtype.element_ty) 234 | 235 | if COMPUTE_DE: 236 | if FILTER_E_GRAD: 237 | should_skip_e = should_skip 238 | else: 239 | should_skip_e = False 240 | 241 | if not should_skip_e: 242 | lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1 243 | 244 | _mm_backward( 245 | d_accum, 246 | dE + (offs_b[:, None] * stride_eb), 247 | dEC + (offs_b[:, None] * stride_eb) if KAHAN_E else None, 248 | offs_b[:, 
None] < BMax, 249 | dELocks + lock_offset, 250 | n_de_locks_1, 251 | C + offs_v[:, None] * stride_cv, 252 | offs_v[:, None] < V, 253 | stride_ed, 254 | stride_cd, 255 | D, 256 | MM_BACK_BLOCK_D, 257 | MM_BACK_EVEN_D, 258 | KAHAN_E, 259 | ) 260 | 261 | if COMPUTE_DC: 262 | if FILTER_C_GRAD: 263 | should_skip_c = should_skip 264 | else: 265 | should_skip_c = False 266 | 267 | if not should_skip_c: 268 | lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1 269 | 270 | _mm_backward( 271 | tl.trans(d_accum), 272 | dC + (offs_v[:, None] * stride_cv), 273 | dCC + (offs_v[:, None] * stride_cv) if KAHAN_C else None, 274 | offs_v[:, None] < V, 275 | dCLocks + lock_offset, 276 | n_dc_locks_1, 277 | E + (offs_b[:, None] * stride_eb), 278 | offs_b[:, None] < BMax, 279 | stride_cd, 280 | stride_ed, 281 | D, 282 | MM_BACK_BLOCK_D, 283 | MM_BACK_EVEN_D, 284 | KAHAN_C, 285 | ) 286 | 287 | 288 | def _cce_back_block_d(args) -> int: 289 | block_d = args["BLOCK_D"] 290 | return 2 * block_d 291 | 292 | 293 | _cce_backward_kernel = triton.jit(_cce_backward_kernel) 294 | _cce_backward_kernel = triton.heuristics( # type: ignore 295 | { 296 | "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, 297 | "MM_BACK_BLOCK_D": lambda args: _cce_back_block_d(args), 298 | "MM_BACK_EVEN_D": lambda args: (args["D"] % _cce_back_block_d(args)) == 0, 299 | "HAS_VALIDS": lambda args: args["Valids"] is not None, 300 | "HAS_BIAS": lambda args: args["Bias"] is not None, 301 | "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, 302 | "HAS_TARGETS": lambda args: args["Targets"] is not None, 303 | "HAS_SOFTCAP": lambda args: args["softcap"] is not None, 304 | "HAS_SHIFT": lambda args: args["shift"] != 0, 305 | "ITEM_DO": lambda args: args["dOut"].numel() == 1, 306 | "GROUP_B": lambda args: 8, 307 | "COMPUTE_DC": lambda args: args["dC"] is not None, 308 | "COMPUTE_DE": lambda args: args["dE"] is not None, 309 | "KAHAN_E": lambda args: args["dEC"] is not None, 310 | "KAHAN_C": lambda args: args["dCC"] is not None, 311 | "COMPUTE_DBIAS": lambda args: args["dBias"] is not None, 312 | } 313 | )(_cce_backward_kernel) 314 | _cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel) # type: ignore 315 | 316 | 317 | def cce_backward_kernel( 318 | do: torch.Tensor, 319 | e: torch.Tensor, 320 | c: torch.Tensor, 321 | bias: torch.Tensor | None, 322 | lse: torch.Tensor, 323 | valids: torch.Tensor | None, 324 | softcap: float | None, 325 | filter_eps: float | None, 326 | targets: torch.Tensor | None = None, 327 | shift: int = 0, 328 | vocab_ordering: torch.Tensor | None = None, 329 | grad_scale: float = 1.0, 330 | accum_e_fp32: bool = False, 331 | accum_c_fp32: bool = False, 332 | filter_e_grad: bool = True, 333 | filter_c_grad: bool = True, 334 | reduce_e_grad: bool = False, 335 | pg: torch.distributed.ProcessGroup | None = None, 336 | ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: 337 | assert do.numel() in (e.size(0), 1) 338 | assert c.size(1) == e.size(1) 339 | assert lse.size(0) == e.size(0) or (valids is not None and lse.size(0) == valids.size(0)) 340 | assert e.dtype in ( 341 | torch.float16, 342 | torch.bfloat16, 343 | ), "Backwards requires embeddings to be bf16 or fp16" 344 | assert c.dtype in ( 345 | torch.float16, 346 | torch.bfloat16, 347 | ), "Backwards requires classifier to be bf16 or fp16" 348 | 349 | do = do.contiguous() 350 | lse = lse.contiguous() 351 | 352 | can_use_fp32_accum = is_triton_greater_or_equal_3_2_0() 353 | 354 | de_dtype = 
torch.float32 if (accum_e_fp32 and can_use_fp32_accum) else None 355 | de = torch.zeros_like(e, dtype=de_dtype) if e.requires_grad else None 356 | 357 | dc_dtype = torch.float32 if (accum_c_fp32 and can_use_fp32_accum) else None 358 | dc = torch.zeros_like(c, dtype=dc_dtype) if c.requires_grad else None 359 | 360 | accum_e_fp32 = accum_e_fp32 and de is not None 361 | accum_c_fp32 = accum_c_fp32 and dc is not None 362 | 363 | if bias is not None: 364 | dbias = torch.zeros_like(bias, dtype=torch.float32) if bias.requires_grad else None 365 | else: 366 | dbias = None 367 | 368 | if de is not None: 369 | assert de.stride() == e.stride() 370 | 371 | if dc is not None: 372 | assert dc.stride() == c.stride() 373 | 374 | if dbias is not None: 375 | assert bias is not None 376 | assert dbias.stride() == bias.stride() 377 | 378 | if accum_e_fp32 and not can_use_fp32_accum: 379 | dec = torch.zeros_like(e) if de is not None else None 380 | else: 381 | dec = None 382 | 383 | if accum_c_fp32 and not can_use_fp32_accum: 384 | dcc = torch.zeros_like(c) if dc is not None else None 385 | else: 386 | dcc = None 387 | 388 | if dec is not None: 389 | assert dec.stride() == e.stride() 390 | 391 | if dcc is not None: 392 | assert dcc.stride() == e.stride() 393 | 394 | if valids is not None: 395 | assert valids.ndim == 1 396 | B = valids.size(0) 397 | else: 398 | B = e.size(0) 399 | 400 | if do.numel() > 1: 401 | do = do.contiguous() 402 | lse = lse.contiguous() 403 | assert do.stride(0) == lse.stride(0), f"{do.stride()=}, {lse.stride()=}" 404 | 405 | def grid(META): 406 | return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(c.size(0), META["BLOCK_V"]),) 407 | 408 | if vocab_ordering is not None: 409 | assert vocab_ordering.ndim == 1 410 | assert vocab_ordering.numel() == c.size(0) 411 | assert vocab_ordering.stride(0) == 1 412 | 413 | nd_locks = triton.cdiv(c.size(1), 64) 414 | if de is not None: 415 | de_locks = e.new_zeros((triton.cdiv(B, 128), nd_locks), dtype=torch.int32) 416 | de_lock_sizes = de_locks.size() 417 | else: 418 | de_locks = None 419 | de_lock_sizes = (None, None) 420 | 421 | if dc is not None: 422 | dc_locks = c.new_zeros((triton.cdiv(c.size(0), 128), nd_locks), dtype=torch.int32) 423 | dc_lock_sizes = dc_locks.size() 424 | else: 425 | dc_locks = None 426 | dc_lock_sizes = (None, None) 427 | 428 | _cce_backward_kernel[grid]( 429 | e, 430 | c, 431 | bias, 432 | lse, 433 | do, 434 | grad_scale, 435 | valids, 436 | vocab_ordering, 437 | softcap, 438 | targets, 439 | de, 440 | dec, 441 | de_locks, 442 | dc, 443 | dcc, 444 | dc_locks, 445 | dbias, 446 | B, 447 | e.size(1), 448 | c.size(0), 449 | e.size(0), 450 | *de_lock_sizes, 451 | *dc_lock_sizes, 452 | e.stride(0), 453 | e.stride(1), 454 | c.stride(0), 455 | c.stride(1), 456 | 1 if bias is None else bias.stride(0), 457 | 1 if valids is None else valids.stride(0), 458 | filter_eps, 459 | shift=shift, 460 | B_BIN=b_bin_fn(B), 461 | FILTER_E_GRAD=filter_e_grad and de is not None, 462 | FILTER_C_GRAD=filter_c_grad and dc is not None, 463 | ) 464 | 465 | if reduce_e_grad and de is not None: 466 | de = vp_reduce_e_grad(de, pg) 467 | 468 | if dbias is not None: 469 | assert bias is not None 470 | dbias = dbias.to(dtype=bias.dtype) 471 | 472 | if dc is not None: 473 | dc = dc.to(dtype=c.dtype) 474 | 475 | if de is not None: 476 | de = de.to(dtype=e.dtype) 477 | 478 | return de, dc, dbias 479 | -------------------------------------------------------------------------------- /cut_cross_entropy/cce_lse_forward.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from typing import Literal, overload 3 | 4 | import torch 5 | import triton 6 | import triton.language as tl 7 | 8 | from cut_cross_entropy.tl_autotune import cce_forward_autotune 9 | from cut_cross_entropy.tl_utils import b_bin_fn, tl_logaddexp, tl_softcapping 10 | 11 | 12 | def _cce_lse_forward_kernel( 13 | E, 14 | C, 15 | Bias, 16 | LSE, 17 | LA, 18 | Locks, 19 | Valids, 20 | softcap, 21 | B, 22 | V, 23 | D, 24 | BMax, 25 | stride_eb, 26 | stride_ed, 27 | stride_cv, 28 | stride_cd, 29 | stride_biasv, 30 | stride_lse_b, 31 | stride_vb, 32 | num_locks, 33 | # Meta-parameters 34 | B_BIN, 35 | HAS_BIAS: tl.constexpr, 36 | HAS_VALIDS: tl.constexpr, 37 | BLOCK_B: tl.constexpr, 38 | BLOCK_V: tl.constexpr, 39 | BLOCK_D: tl.constexpr, # 40 | GROUP_B: tl.constexpr, # 41 | EVEN_D: tl.constexpr, 42 | HAS_SOFTCAP: tl.constexpr, 43 | HAS_LA: tl.constexpr, 44 | DOT_PRECISION: tl.constexpr, 45 | ): 46 | pid = tl.program_id(axis=0) 47 | num_pid_b = tl.cdiv(B, BLOCK_B) 48 | num_pid_v = tl.cdiv(V, BLOCK_V) 49 | num_pid_in_group = GROUP_B * num_pid_v 50 | group_id = pid // num_pid_in_group 51 | first_pid_b = group_id * GROUP_B 52 | group_size_b = min(num_pid_b - first_pid_b, GROUP_B) 53 | pid_b = (first_pid_b + ((pid % num_pid_in_group) % group_size_b)).to(tl.int64) 54 | pid_v = ((pid % num_pid_in_group) // group_size_b).to(tl.int64) 55 | 56 | offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64) 57 | if HAS_VALIDS: 58 | offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax).to(tl.int64) 59 | 60 | offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)).to(tl.int64) 61 | offs_d = tl.arange(0, BLOCK_D).to(tl.int64) 62 | e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) 63 | c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) 64 | 65 | accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) 66 | for d in range(0, tl.cdiv(D, BLOCK_D)): 67 | e_mask = offs_b[:, None] < BMax 68 | if not EVEN_D: 69 | e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D)) 70 | 71 | e = tl.load(e_ptrs, mask=e_mask, other=0.0) 72 | 73 | c_mask = offs_v[None, :] < V 74 | if not EVEN_D: 75 | c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D)) 76 | 77 | c = tl.load(c_ptrs, mask=c_mask, other=0.0) 78 | 79 | accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION) 80 | 81 | e_ptrs += BLOCK_D * stride_ed 82 | c_ptrs += BLOCK_D * stride_cd 83 | 84 | tl.debug_barrier() 85 | 86 | if HAS_BIAS: 87 | bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0) 88 | bias = bias.to(dtype=accum.dtype) 89 | accum += bias[None, :] 90 | 91 | logits = tl.where(offs_v[None, :] < V, accum, -float("inf")) 92 | if HAS_SOFTCAP: 93 | logits = tl_softcapping(logits, softcap) 94 | 95 | if HAS_LA: 96 | this_avg_logit = tl.sum(logits, 0) / B 97 | tl.atomic_add(LA + offs_v, this_avg_logit, mask=offs_v < V) 98 | 99 | this_mx = tl.max(logits, axis=1) 100 | e = tl.exp(logits - this_mx[:, None]) 101 | this_lse = this_mx + tl.log(tl.sum(e, axis=1)) 102 | 103 | offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)).to(tl.int64) 104 | o_mask = offs_b < B 105 | 106 | lse_ptrs = LSE + (stride_lse_b * offs_b) 107 | 108 | this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks)) 109 | while tl.atomic_cas(this_locks, 0, 1) == 1: 110 | pass 111 | 112 | lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last") 113 | lse = 
tl_logaddexp(lse, this_lse) 114 | tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last") 115 | 116 | tl.debug_barrier() 117 | tl.atomic_xchg(this_locks, 0) 118 | 119 | 120 | _cce_lse_forward_kernel = triton.jit(_cce_lse_forward_kernel) 121 | _cce_lse_forward_kernel = triton.heuristics( # type: ignore 122 | { 123 | "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, 124 | "HAS_BIAS": lambda args: args["Bias"] is not None, 125 | "HAS_VALIDS": lambda args: args["Valids"] is not None, 126 | "HAS_SOFTCAP": lambda args: args["softcap"] is not None, 127 | "HAS_LA": lambda args: args["LA"] is not None, 128 | "GROUP_B": lambda args: 8, 129 | "DOT_PRECISION": lambda args: "tf32" 130 | if torch.get_float32_matmul_precision() == "high" 131 | else "ieee", 132 | } 133 | )(_cce_lse_forward_kernel) 134 | _cce_lse_forward_kernel = cce_forward_autotune()(_cce_lse_forward_kernel) # type: ignore 135 | 136 | 137 | @overload 138 | def cce_lse_forward_kernel( 139 | e: torch.Tensor, 140 | c: torch.Tensor, 141 | bias: torch.Tensor | None = None, 142 | valids: torch.Tensor | None = None, 143 | softcap: float | None = None, 144 | return_logit_avg: Literal[False] = False, 145 | ) -> torch.Tensor: ... 146 | 147 | 148 | @overload 149 | def cce_lse_forward_kernel( 150 | e: torch.Tensor, 151 | c: torch.Tensor, 152 | bias: torch.Tensor | None = None, 153 | valids: torch.Tensor | None = None, 154 | softcap: float | None = None, 155 | return_logit_avg: Literal[True] = True, 156 | ) -> tuple[torch.Tensor, torch.Tensor]: ... 157 | 158 | 159 | @overload 160 | def cce_lse_forward_kernel( 161 | e: torch.Tensor, 162 | c: torch.Tensor, 163 | bias: torch.Tensor | None = None, 164 | valids: torch.Tensor | None = None, 165 | softcap: float | None = None, 166 | return_logit_avg: bool = False, 167 | ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: ... 168 | 169 | 170 | def cce_lse_forward_kernel( 171 | e: torch.Tensor, 172 | c: torch.Tensor, 173 | bias: torch.Tensor | None = None, 174 | valids: torch.Tensor | None = None, 175 | softcap: float | None = None, 176 | return_logit_avg: bool = False, 177 | ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: 178 | # Check constraints. 179 | assert e.shape[1] == c.shape[1], "Incompatible dimensions" 180 | assert e.is_contiguous(), "Matrix A must be contiguous" 181 | if valids is not None: 182 | assert valids.ndim == 1 183 | B = valids.numel() 184 | else: 185 | B, _ = e.shape 186 | 187 | if bias is not None: 188 | assert bias.ndim == 1 189 | assert c.shape[0] == bias.shape[0] 190 | 191 | V, D = c.shape 192 | # Allocates output. 193 | lse = e.new_full((B,), -float("inf"), dtype=torch.float32) 194 | locks = e.new_full( 195 | (triton.cdiv(B, 128),), 196 | 0, 197 | dtype=torch.uint32, 198 | ) 199 | 200 | if return_logit_avg: 201 | logit_avg = e.new_full((V,), 0.0, dtype=torch.float32) 202 | else: 203 | logit_avg = None 204 | 205 | # 1D launch kernel where each block gets its own program. 
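    # Worked example (illustrative numbers, not from the original source): with B = 4096 valid
    # rows, V = 32000 classes, BLOCK_B = 256 and BLOCK_V = 128, the grid below launches
    # cdiv(4096, 256) * cdiv(32000, 128) = 16 * 250 = 4000 programs. Each program computes one
    # (BLOCK_B, BLOCK_V) tile of logits and folds its partial log-sum-exp into `lse` under the
    # per-row-block lock, so `lse` ends up holding the LSE over all V classes.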
206 | def grid(META) -> tuple[int]: 207 | return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(V, META["BLOCK_V"]),) 208 | 209 | _cce_lse_forward_kernel[grid]( 210 | e, 211 | c, 212 | bias, 213 | lse, # 214 | logit_avg, 215 | locks, 216 | valids, 217 | softcap, 218 | B, 219 | V, 220 | D, # 221 | e.size(0), 222 | e.stride(0), 223 | e.stride(1), # 224 | c.stride(0), 225 | c.stride(1), # 226 | 1 if bias is None else bias.stride(0), 227 | lse.stride(0), 228 | 1 if valids is None else valids.stride(0), 229 | num_locks=locks.size(0), 230 | B_BIN=b_bin_fn(B), 231 | ) 232 | 233 | if return_logit_avg: 234 | assert logit_avg is not None 235 | return lse, logit_avg 236 | else: 237 | return lse 238 | -------------------------------------------------------------------------------- /cut_cross_entropy/cce_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | import enum 4 | from enum import auto 5 | from typing import TypedDict 6 | 7 | 8 | class LinearCrossEntropyImpl(enum.IntEnum): 9 | CCE = auto() 10 | TORCH_COMPILE = auto() 11 | CCE_KAHAN_FULL_C = auto() 12 | CCE_KAHAN_FULL_E = auto() 13 | 14 | CCE_EXACT = auto() 15 | CCE_KAHAN_FULL_C_FULL_E = auto() 16 | CCE_KAHAN_FULL = auto() 17 | 18 | 19 | class CCEPreset(TypedDict): 20 | filter_eps: float | str | None 21 | accum_e_fp32: bool 22 | accum_c_fp32: bool 23 | filter_e_grad: bool 24 | filter_c_grad: bool 25 | 26 | 27 | class CCEPresets: 28 | names: set[str] = set( 29 | v.name.lower() for v in LinearCrossEntropyImpl if v.name.lower() != "torch_compile" 30 | ) 31 | 32 | @classmethod 33 | def build_for_impl(cls, impl: str, opts: CCEPreset) -> CCEPreset: 34 | if impl not in cls.names: 35 | raise ValueError(f"{impl!r} not in {cls.names}") 36 | 37 | if impl == "cce": 38 | return opts 39 | 40 | opts = opts.copy() 41 | if impl in ("cce_exact", "cce_kahan_full", "cce_kahan_full_c_full_e"): 42 | opts["filter_eps"] = None 43 | opts["accum_e_fp32"] = True 44 | opts["accum_c_fp32"] = True 45 | 46 | return opts 47 | 48 | if impl == "cce_kahan_full_c": 49 | opts["filter_eps"] = "auto" 50 | 51 | opts["accum_c_fp32"] = True 52 | opts["filter_c_grad"] = False 53 | 54 | opts["accum_e_fp32"] = True 55 | opts["filter_e_grad"] = True 56 | 57 | return opts 58 | 59 | if impl == "cce_kahan_full_e": 60 | opts["filter_eps"] = "auto" 61 | 62 | opts["accum_c_fp32"] = True 63 | opts["filter_c_grad"] = True 64 | 65 | opts["accum_e_fp32"] = True 66 | opts["filter_e_grad"] = False 67 | 68 | return opts 69 | 70 | raise NotImplementedError(f"{impl=}") 71 | -------------------------------------------------------------------------------- /cut_cross_entropy/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | IGNORE_INDEX: int = -100 3 | -------------------------------------------------------------------------------- /cut_cross_entropy/doc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl 3 | 4 | LINEAR_CROSS_ENTROPY_DOC = """Computes cross-entropy loss using the logits generated by performing 5 | the matrix multiplication between the embeddings (e) and classifier (c). 6 | 7 | This method saves GPU memory by not materializing the logits into GPU 8 | main memory. 
9 | 10 | 11 | Specifically, this computes 12 | 13 | ```python 14 | 15 | loss = F.cross_entropy((e @ c.T).float(), targets) 16 | ``` 17 | 18 | without allocating the intermediary (e @ c.T).float() matrix. 19 | 20 | :param e: Embedding of the inputs used to compute the logits. Shape (..., D) 21 | :param c: Classifier matrix. Shape (NumClasses, D) 22 | :param targets: The target class for each input. Values must be in [0, NumClasses). Shape (...) 23 | :param ignore_index: If an input has a target of this value, it is ignored in the loss computation. 24 | :param softcap: The value for logit softcapping. 25 | :param reduction: The reduction to perform over the loss. Supports "mean", "sum", and "none". 26 | :param shift: When non-zero, the embedding and targets will be shifted along the temporal axis to perform nth-next token prediction. 27 | Specifically, this is used to efficiently compute the following 28 | 29 | ```python 30 | shift_e = e[..., :-shift, :].flatten(0, -2) 31 | shift_targets = targets[..., shift:].flatten() 32 | 33 | loss = F.cross_entropy((shift_e @ c.T).float(), shift_targets) 34 | ``` 35 | 36 | If given a boolean value, False will be treated as zero and True will be treated as one. 37 | 38 | When this value is non-zero or True, e and targets must have shape (..., T, D) and (..., T), respectively. 39 | 40 | Integer values must be in [0, T) 41 | """ 42 | 43 | CCE_OPTS_DOC = [ 44 | """ 45 | :param filter_eps: The threshold value used to determine which locations can be safely ignored 46 | in gradient computation. The default value of "auto" will automatically choose a value 47 | based on the input dtype.""", 48 | """ 49 | :param accum_e_fp32: Whether or not to use fp32 accumulation for dE (use Kahan summation for Triton < 3.2 to work around a bug). 50 | This is useful when working with models with a very large vocabulary or very long sequence lengths.""", 51 | """:param accum_c_fp32: Whether or not to use fp32 accumulation for dC (use Kahan summation for Triton < 3.2 to work around a bug). 52 | This is useful when working with models with a very large vocabulary or very long sequence lengths.""", 53 | """ 54 | :param filter_e_grad: Whether or not to apply gradient filter to the embedding gradient (dE). If filter_eps is None, this 55 | will be set to False.""", 56 | """ 57 | :param filter_c_grad: Whether or not to apply gradient filter to the classifier gradient (dC). If filter_eps is None, this 58 | will be set to False.""", 59 | ] 60 | 61 | IMPL_DOC = """ 62 | :param impl: The implementation to use. Can be one of 63 | {impls}. 64 | If using one of the 'cce' implementations that is not 'cce', 65 | the cce parameters will be set automatically and any user-provided 66 | values will be ignored. 67 | """.format(impls=set(v.name.lower() for v in LinearCrossEntropyImpl)) 68 | 69 | 70 | def add_doc_start(*docstr: str): 71 | def add_doc(fn): 72 | fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") 73 | 74 | return fn 75 | 76 | return add_doc 77 | -------------------------------------------------------------------------------- /cut_cross_entropy/indexed_dot.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
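# Reference semantics (an illustrative sketch, not part of the original file): for each valid
# row, this kernel computes the negative logit of the correct class without building the full
# logit matrix. A naive PyTorch equivalent would be roughly:
#
#   rows = valids if valids is not None else torch.arange(e.size(0))
#   picked = c[inds[rows + shift]]                        # (B, D) gathered classifier rows
#   out = -(e[rows].float() * picked.float()).sum(-1)     # negative correct-class dot product
#   if bias is not None:
#       out -= bias[inds[rows + shift]].float()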
2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | from cut_cross_entropy.tl_autotune import indexed_dot_autotune 7 | from cut_cross_entropy.tl_utils import b_bin_fn 8 | from cut_cross_entropy.utils import softcapping 9 | 10 | 11 | def _indexed_neg_dot_forward_kernel( 12 | E, 13 | C, 14 | Inds, 15 | Bias, 16 | Valids, 17 | Out, 18 | B, 19 | D, 20 | V, 21 | BMax, 22 | stride_eb, 23 | stride_ed, 24 | stride_cv, 25 | stride_cd, 26 | stride_ib, 27 | stride_biasv, 28 | stride_vb, 29 | shift, 30 | B_BIN, 31 | BLOCK_B: tl.constexpr, 32 | BLOCK_D: tl.constexpr, 33 | GROUP_B: tl.constexpr, 34 | HAS_BIAS: tl.constexpr, 35 | HAS_VALIDS: tl.constexpr, 36 | EVEN_D: tl.constexpr, 37 | HAS_SHIFT: tl.constexpr, 38 | ): 39 | pid = tl.program_id(axis=0) 40 | num_b_chunks = tl.cdiv(B, BLOCK_B) 41 | num_d_chunks = tl.cdiv(D, BLOCK_D) 42 | num_d_in_group = GROUP_B * num_d_chunks 43 | group_id = pid // num_d_in_group 44 | first_pid_b = group_id * GROUP_B 45 | group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) 46 | pid_b = (first_pid_b + ((pid % num_d_in_group) % group_size_b)).to(tl.int64) 47 | pid_d = ((pid % num_d_in_group) // group_size_b).to(tl.int64) 48 | 49 | offs_b = (tl.arange(0, BLOCK_B) + pid_b * BLOCK_B).to(tl.int64) 50 | if HAS_VALIDS: 51 | offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax).to(tl.int64) 52 | 53 | offs_d = (tl.arange(0, BLOCK_D) + pid_d * BLOCK_D).to(tl.int64) 54 | e_ptrs = E + (stride_eb * offs_b[:, None] + stride_ed * offs_d[None, :]) 55 | 56 | e_mask = offs_b[:, None] < BMax 57 | if not EVEN_D: 58 | e_mask = e_mask & (offs_d[None, :] < D) 59 | 60 | e = tl.load(e_ptrs, mask=e_mask, other=0.0) 61 | 62 | if HAS_SHIFT: 63 | offs_b = offs_b + shift 64 | 65 | inds = tl.load(Inds + stride_ib * offs_b, mask=offs_b < BMax, other=V) 66 | 67 | c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) 68 | 69 | c_mask = inds[:, None] < V 70 | if not EVEN_D: 71 | c_mask = c_mask & (offs_d[None, :] < D) 72 | 73 | c = tl.load(c_ptrs, mask=c_mask, other=0.0) 74 | 75 | dot = e.to(tl.float32) * c.to(tl.float32) 76 | neg_dot = -tl.sum(dot, 1) 77 | 78 | if HAS_BIAS: 79 | bias = tl.load(Bias + inds * stride_biasv, mask=inds < V, other=0.0) 80 | bias = bias.to(tl.float32) 81 | neg_dot -= bias 82 | 83 | offs_b = (tl.arange(0, BLOCK_B) + pid_b * BLOCK_B).to(tl.int64) 84 | out_ptrs = Out + offs_b 85 | tl.atomic_add(out_ptrs, neg_dot.to(out_ptrs.dtype.element_ty), mask=offs_b < B) 86 | 87 | 88 | _indexed_neg_dot_forward_kernel = triton.jit(_indexed_neg_dot_forward_kernel) 89 | _indexed_neg_dot_forward_kernel = triton.heuristics( # type: ignore 90 | { 91 | "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, 92 | "HAS_BIAS": lambda args: args["Bias"] is not None, 93 | "HAS_VALIDS": lambda args: args["Valids"] is not None, 94 | "HAS_SHIFT": lambda args: args["shift"] != 0, 95 | "GROUP_B": lambda args: 8, 96 | } 97 | )(_indexed_neg_dot_forward_kernel) 98 | _indexed_neg_dot_forward_kernel = indexed_dot_autotune()(_indexed_neg_dot_forward_kernel) # type: ignore 99 | 100 | 101 | def indexed_neg_dot_forward_kernel( 102 | e: torch.Tensor, 103 | c: torch.Tensor, 104 | inds: torch.Tensor, 105 | bias: torch.Tensor | None = None, 106 | shift: int = 0, 107 | valids: torch.Tensor | None = None, 108 | softcap: float | None = None, 109 | out_dtype: torch.dtype | None = None, 110 | ) -> torch.Tensor: 111 | assert inds.ndim == 1 112 | assert e.ndim == 2 113 | assert c.ndim == 2 114 | assert inds.size(0) == e.size(0) 115 | assert c.size(1) == e.size(1) 
116 | 117 | if valids is not None: 118 | assert valids.ndim == 1 119 | B = valids.size(0) 120 | else: 121 | B = e.size(0) 122 | 123 | out = e.new_zeros((B,), dtype=torch.float32) 124 | 125 | def grid(META) -> tuple[int]: 126 | return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(e.size(1), META["BLOCK_D"]),) 127 | 128 | _indexed_neg_dot_forward_kernel[grid]( 129 | e, 130 | c, 131 | inds, 132 | bias, 133 | valids, 134 | out, 135 | B, 136 | e.size(1), 137 | c.size(0), 138 | e.size(0), 139 | e.stride(0), 140 | e.stride(1), 141 | c.stride(0), 142 | c.stride(1), 143 | inds.stride(0), 144 | 1 if bias is None else bias.stride(0), 145 | 1 if valids is None else valids.stride(0), 146 | shift=shift, 147 | B_BIN=b_bin_fn(B), 148 | ) 149 | 150 | if softcap is not None: 151 | out = softcapping(out, softcap) 152 | 153 | if out_dtype is None: 154 | out_dtype = e.dtype 155 | 156 | out = out.to(out_dtype) 157 | 158 | return out 159 | -------------------------------------------------------------------------------- /cut_cross_entropy/linear_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import platform 3 | from typing import TYPE_CHECKING 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from cut_cross_entropy.cce_utils import CCEPreset, CCEPresets, LinearCrossEntropyImpl 9 | from cut_cross_entropy.constants import IGNORE_INDEX 10 | from cut_cross_entropy.doc import CCE_OPTS_DOC, IMPL_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start 11 | from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy 12 | from cut_cross_entropy.utils import is_torch_greater_or_equal_2_5 13 | from cut_cross_entropy.vocab_parallel import VocabParallelOptions 14 | 15 | PLATFORM_SYSTEM = platform.system() 16 | 17 | if TYPE_CHECKING or PLATFORM_SYSTEM != "Darwin": 18 | from cut_cross_entropy.cce import cce_linear_cross_entropy 19 | 20 | LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.CCE 21 | else: 22 | cce_linear_cross_entropy = None 23 | LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.TORCH_COMPILE 24 | 25 | if TYPE_CHECKING or is_torch_greater_or_equal_2_5(): 26 | import torch.distributed.tensor 27 | 28 | 29 | is_d_tensor_error_message = ( 30 | "Received {name} as a torch.distributed.tensor.DTensor. " 31 | "This is not supported. " 32 | "If possible, change the sharding strategy such that {name} is already unsharded. " 33 | "If not, see https://github.com/apple/ml-cross-entropy/issues/31." 34 | ) 35 | 36 | 37 | @add_doc_start(LINEAR_CROSS_ENTROPY_DOC) 38 | @add_doc_start(*(doc_str + " Only valid for the cce implementation.\n" for doc_str in CCE_OPTS_DOC)) 39 | @add_doc_start(IMPL_DOC) 40 | def linear_cross_entropy( 41 | e: torch.Tensor, 42 | c: torch.Tensor, 43 | targets: torch.Tensor, 44 | bias: torch.Tensor | None = None, 45 | ignore_index: int = IGNORE_INDEX, 46 | softcap: float | None = None, 47 | reduction: str = "mean", 48 | shift: bool | int = 0, 49 | filter_eps: float | str | None = "auto", 50 | accum_e_fp32: bool = False, 51 | accum_c_fp32: bool = False, 52 | filter_e_grad: bool = True, 53 | filter_c_grad: bool = True, 54 | impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, 55 | vocab_parallel_options: VocabParallelOptions | None = None, 56 | ) -> torch.Tensor: 57 | """ 58 | :param impl: The linear cross entropy implementation to use. Currently supports cce, torch_compile, and cce_exact. 
59 | """ 60 | 61 | if is_torch_greater_or_equal_2_5(): 62 | maybe_tensor_inputs = dict(e=e, c=c, targets=targets, bias=bias) 63 | for k, v in maybe_tensor_inputs.items(): 64 | if isinstance(v, torch.distributed.tensor.DTensor): 65 | raise ValueError(is_d_tensor_error_message.format(name=k)) 66 | 67 | if isinstance(impl, LinearCrossEntropyImpl): 68 | impl = impl.name.lower() 69 | 70 | if isinstance(shift, int) and (shift < 0 or shift >= targets.size(-1)): 71 | raise ValueError(f"Shift must be in the range [0, {targets.size(-1)}). Got {shift}.") 72 | 73 | if vocab_parallel_options is not None: 74 | expected_v_dim_size = vocab_parallel_options.stop - vocab_parallel_options.start 75 | if c.size(0) != expected_v_dim_size: 76 | raise ValueError(f"Expected c.size(0) to be {expected_v_dim_size}, got {c.size(0)}.") 77 | 78 | if bias is not None and bias.size(0) != c.size(0): 79 | raise ValueError( 80 | f"Bias has a different number of elements than c. {bias.size(0)} vs. {c.size(0)}." 81 | ) 82 | 83 | if impl in CCEPresets.names: 84 | if platform.system() == "Darwin": 85 | raise RuntimeError( 86 | "CCE does not support MacOS. Please use torch_compile when running on MacOS instead." 87 | ) 88 | 89 | cce_opts = CCEPresets.build_for_impl( 90 | impl, 91 | CCEPreset( 92 | filter_eps=filter_eps, 93 | accum_e_fp32=accum_e_fp32, 94 | accum_c_fp32=accum_c_fp32, 95 | filter_e_grad=filter_e_grad, 96 | filter_c_grad=filter_c_grad, 97 | ), 98 | ) 99 | 100 | assert cce_linear_cross_entropy is not None 101 | return cce_linear_cross_entropy( 102 | e, 103 | c, 104 | targets, 105 | bias, 106 | ignore_index, 107 | softcap, 108 | reduction, 109 | shift, 110 | **cce_opts, 111 | vocab_parallel_options=vocab_parallel_options, 112 | ) 113 | elif impl == "torch_compile": 114 | return torch_compile_linear_cross_entropy( 115 | e, 116 | c, 117 | targets, 118 | bias, 119 | ignore_index, 120 | softcap, 121 | reduction, 122 | shift, 123 | vocab_parallel_options=vocab_parallel_options, 124 | ) 125 | else: 126 | raise NotImplementedError(f"{impl} is not implemented.") 127 | 128 | 129 | class LinearCrossEntropy(nn.Module): 130 | def __init__( 131 | self, 132 | ignore_index: int = IGNORE_INDEX, 133 | softcap: float | None = None, 134 | reduction: str = "mean", 135 | shift: bool | int = 0, 136 | filter_eps: float | str | None = "auto", 137 | accum_e_fp32: bool = False, 138 | accum_c_fp32: bool = False, 139 | filter_e_grad: bool = True, 140 | filter_c_grad: bool = True, 141 | impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, 142 | ): 143 | super().__init__() 144 | self.ignore_index = ignore_index 145 | self.softcap = softcap 146 | self.reduction = reduction 147 | self.filter_eps = filter_eps 148 | self.shift = shift 149 | 150 | self.accum_e_fp32 = accum_e_fp32 151 | self.accum_c_fp32 = accum_c_fp32 152 | 153 | self.filter_e_grad = filter_e_grad 154 | self.filter_c_grad = filter_c_grad 155 | 156 | self.impl = impl 157 | 158 | def forward( 159 | self, 160 | e: torch.Tensor, 161 | c: torch.Tensor, 162 | targets: torch.Tensor, 163 | bias: torch.Tensor | None = None, 164 | ) -> torch.Tensor: 165 | return linear_cross_entropy( 166 | e, 167 | c, 168 | targets, 169 | bias=bias, 170 | ignore_index=self.ignore_index, 171 | softcap=self.softcap, 172 | reduction=self.reduction, 173 | shift=self.shift, 174 | filter_eps=self.filter_eps, 175 | accum_e_fp32=self.accum_e_fp32, 176 | accum_c_fp32=self.accum_c_fp32, 177 | filter_e_grad=self.filter_e_grad, 178 | filter_c_grad=self.filter_c_grad, 179 | impl=self.impl, 180 | ) 181 | 
-------------------------------------------------------------------------------- /cut_cross_entropy/tl_autotune.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import functools 3 | import heapq 4 | import os 5 | from dataclasses import dataclass, field 6 | from typing import Any, Callable 7 | 8 | import torch 9 | import triton 10 | from triton import Config, cdiv 11 | from triton.runtime import autotuner, driver 12 | from triton.testing import ( 13 | get_dram_gbps, 14 | get_max_simd_tflops, 15 | get_max_tensorcore_tflops, 16 | nvsmi, 17 | ) 18 | 19 | from cut_cross_entropy.tl_utils import is_triton_greater_or_equal_3_2_0 20 | 21 | _AUTOTUNE: bool = os.getenv("CCE_AUTOTUNE", "0") != "0" 22 | 23 | 24 | @dataclass 25 | class NoneSupportRestorer: 26 | reset_idx_or_name: list[int | str] 27 | restore_idx_or_name: list[int | str] 28 | _restore_copies: dict[str | int, torch.Tensor | None] = field(default_factory=dict, init=False) 29 | 30 | def pre_hook( 31 | self, 32 | args: list[torch.Tensor | None | Any] | dict[str, torch.Tensor | None | Any], 33 | reset_only: bool = False, 34 | ) -> None: 35 | for i in self.reset_idx_or_name: 36 | if isinstance(i, str): 37 | assert isinstance(args, dict) 38 | v = args[i] 39 | else: 40 | assert isinstance(args, list) 41 | v = args[i] 42 | 43 | if v is not None: 44 | assert isinstance(v, torch.Tensor) 45 | v.zero_() 46 | 47 | if not reset_only: 48 | for i in self.restore_idx_or_name: 49 | if isinstance(i, str): 50 | assert isinstance(args, dict) 51 | v = args[i] 52 | else: 53 | assert isinstance(args, list) 54 | v = args[i] 55 | 56 | if v is not None: 57 | assert isinstance(v, torch.Tensor) 58 | self._restore_copies[i] = v.clone() 59 | else: 60 | self._restore_copies[i] = None 61 | 62 | def post_hook( 63 | self, 64 | args: list[torch.Tensor | None | Any] | dict[str, torch.Tensor | None | Any], 65 | exception=None, 66 | ) -> None: 67 | for i, old_v in self._restore_copies.items(): 68 | if isinstance(i, str): 69 | assert isinstance(args, dict) 70 | v = args[i] 71 | else: 72 | assert isinstance(args, list) 73 | v = args[i] 74 | 75 | if v is not None: 76 | assert isinstance(v, torch.Tensor) 77 | assert old_v is not None 78 | 79 | v.copy_(old_v) 80 | 81 | self._restore_copies = {} 82 | 83 | 84 | @functools.wraps(triton.autotune) 85 | def _cce_autotune(*args, **kwargs) -> Callable[..., autotuner.Autotuner]: 86 | def decorator(fn): 87 | reset_idx_or_name = [] 88 | restore_idx_or_name = [] 89 | arg_names = fn.arg_names 90 | reset_idx_or_name = kwargs.pop("reset_to_zero", []) 91 | if not is_triton_greater_or_equal_3_2_0(): 92 | reset_idx_or_name = [arg_names.index(k) for k in reset_idx_or_name] 93 | 94 | restore_idx_or_name = kwargs.pop("restore_value", []) 95 | if not is_triton_greater_or_equal_3_2_0(): 96 | restore_idx_or_name = [arg_names.index(k) for k in restore_idx_or_name] 97 | 98 | restorer = NoneSupportRestorer(reset_idx_or_name, restore_idx_or_name) 99 | if len(reset_idx_or_name) > 0: 100 | kwargs["pre_hook"] = restorer.pre_hook 101 | 102 | if len(restore_idx_or_name) > 0: 103 | kwargs["post_hook"] = restorer.post_hook 104 | 105 | return triton.autotune(*args, **kwargs)(fn) 106 | 107 | return decorator 108 | 109 | 110 | @functools.lru_cache() 111 | def get_clock_rate_in_khz(): 112 | try: 113 | return nvsmi(["clocks.max.sm"])[0] * 1e3 114 | except FileNotFoundError: 115 | import pynvml 116 | 117 | pynvml.nvmlInit() 118 | handle =
pynvml.nvmlDeviceGetHandleByIndex(0) 119 | return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 120 | 121 | 122 | def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): 123 | """return compute throughput in TOPS""" 124 | total_warps = num_ctas * min(num_warps, 4) 125 | num_subcores = ( 126 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 127 | ) # on recent GPUs 128 | tflops = ( 129 | min(num_subcores, total_warps) 130 | / num_subcores 131 | * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) 132 | ) 133 | return tflops 134 | 135 | 136 | def get_simd_tflops(device, num_ctas, num_warps, dtype): 137 | """return compute throughput in TOPS""" 138 | total_warps = num_ctas * min(num_warps, 4) 139 | num_subcores = ( 140 | driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 141 | ) # on recent GPUs 142 | tflops = ( 143 | min(num_subcores, total_warps) 144 | / num_subcores 145 | * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) 146 | ) 147 | return tflops 148 | 149 | 150 | def get_tflops(device, num_ctas, num_warps, dtype): 151 | capability = torch.cuda.get_device_capability(device) 152 | if capability[0] < 8 and dtype == torch.float32: 153 | return get_simd_tflops(device, num_ctas, num_warps, dtype) 154 | return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) 155 | 156 | 157 | def early_config_prune( 158 | configs, 159 | named_args, 160 | *, 161 | shared_memory_factor: float = 1.0, 162 | max_num_warps: int | None = None, 163 | **kwargs, 164 | ): 165 | device = torch.cuda.current_device() 166 | capability = torch.cuda.get_device_capability() 167 | # BLOCK_B, BLOCK_V, BLOCK_D, SPLIT_K, num_warps, num_stages 168 | dtsize = named_args["E"].element_size() 169 | 170 | if max_num_warps is not None: 171 | configs = [config for config in configs if config.num_warps <= max_num_warps] 172 | 173 | # 1. make sure we have enough smem 174 | pruned_configs = [] 175 | for config in configs: 176 | kw = config.kwargs 177 | BLOCK_B, BLOCK_V, BLOCK_D, num_stages = ( 178 | kw["BLOCK_B"], 179 | kw["BLOCK_V"], 180 | kw["BLOCK_D"], 181 | config.num_stages, 182 | ) 183 | 184 | max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] 185 | required_shared_memory = ( 186 | shared_memory_factor * (BLOCK_B + BLOCK_V) * BLOCK_D * num_stages * dtsize 187 | ) 188 | if required_shared_memory > max_shared_memory: 189 | continue 190 | 191 | pruned_configs.append(config) 192 | 193 | configs = pruned_configs 194 | 195 | # group configs by (BLOCK_B,_N,_K, num_warps) 196 | configs_map = {} 197 | for config in configs: 198 | kw = config.kwargs 199 | BLOCK_B, BLOCK_V, BLOCK_D, num_warps, num_stages = ( 200 | kw["BLOCK_B"], 201 | kw["BLOCK_V"], 202 | kw["BLOCK_D"], 203 | config.num_warps, 204 | config.num_stages, 205 | ) 206 | 207 | key = (BLOCK_B, BLOCK_V, BLOCK_D, num_warps) 208 | if key in configs_map: 209 | configs_map[key].append((config, num_stages)) 210 | else: 211 | configs_map[key] = [(config, num_stages)] 212 | 213 | pruned_configs = [] 214 | for k, v in configs_map.items(): 215 | BLOCK_B, BLOCK_V, BLOCK_D, num_warps = k 216 | if capability[0] >= 8: 217 | # compute cycles (only works for ampere GPUs) 218 | mmas = BLOCK_B * BLOCK_V * BLOCK_D / (16 * 8 * 16) 219 | mma_cycles = mmas / min(4, num_warps) * 8 220 | 221 | ldgsts_latency = 300 # Does this matter? 
222 | optimal_num_stages = ldgsts_latency / mma_cycles 223 | 224 | # nearest stages, prefer large #stages 225 | nearest = heapq.nsmallest( 226 | 2, 227 | v, 228 | key=lambda x: 10 + abs(x[1] - optimal_num_stages) 229 | if (x[1] - optimal_num_stages) < 0 230 | else x[1] - optimal_num_stages, 231 | ) 232 | 233 | for n in nearest: 234 | pruned_configs.append(n[0]) 235 | else: # Volta & Turing only supports num_stages <= 2 236 | random_config = v[0][0] 237 | random_config.num_stages = 2 238 | pruned_configs.append(random_config) 239 | return pruned_configs 240 | 241 | 242 | def _total_ops_fn(B, V, D) -> float: 243 | return 2 * B * V * D + 10 * B * V 244 | 245 | 246 | def _total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): 247 | return B * dtsize 248 | 249 | 250 | def estimate_matmul_time( 251 | # backend, device, 252 | num_warps, 253 | num_stages, # 254 | E, 255 | B, 256 | V, 257 | D, # 258 | BLOCK_B, 259 | BLOCK_V, 260 | BLOCK_D, 261 | debug=False, 262 | total_ops_fn=_total_ops_fn, 263 | total_store_fn=_total_store_fn, 264 | **kwargs, # 265 | ): 266 | """return estimated running time in ms 267 | = max(compute, loading) + store""" 268 | device = torch.cuda.current_device() 269 | dtype = E.dtype 270 | dtsize = E.element_size() 271 | 272 | num_cta_b = cdiv(B, BLOCK_B) 273 | num_cta_v = cdiv(V, BLOCK_V) 274 | num_ctas = num_cta_b * num_cta_v 275 | 276 | # If the input is smaller than the block size 277 | B, V = max(B, BLOCK_B), max(V, BLOCK_V) 278 | 279 | # time to compute 280 | total_ops = total_ops_fn(B, V, D) 281 | total_ops = total_ops / (1024 * 1024 * 1024) # GOPS 282 | tput = get_tflops(device, num_ctas, num_warps, dtype) 283 | compute_ms = total_ops / tput 284 | 285 | # time to load data 286 | num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] 287 | active_cta_ratio = min(1, num_ctas / num_sm) 288 | active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate 289 | active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% 290 | dram_bw = get_dram_gbps(device) * ( 291 | active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 292 | ) # in GB/s 293 | l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
294 | # assume 80% of (following) loads are in L2 cache 295 | load_a_dram = B * D * dtsize * (1 + 0.2 * (num_cta_v - 1)) 296 | load_a_l2 = B * D * dtsize * 0.8 * (num_cta_v - 1) 297 | load_b_dram = V * D * dtsize * (1 + 0.2 * (num_cta_b - 1)) 298 | load_b_l2 = V * D * dtsize * 0.8 * (num_cta_b - 1) 299 | # total 300 | total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB 301 | total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) 302 | # loading time in ms 303 | load_ms = total_dram / dram_bw + total_l2 / l2_bw 304 | 305 | # estimate storing time 306 | store_bw = dram_bw * 0.4 # :o 307 | store_dram = total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v) / (1024 * 1024) 308 | store_ms = store_dram / store_bw 309 | 310 | total_time_ms = max(compute_ms, load_ms) + store_ms 311 | if debug: 312 | print( 313 | f"{BLOCK_B=}, {BLOCK_V=}, {BLOCK_D=}, {num_warps=}, {num_stages=}, " 314 | f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " 315 | f"loading time: {load_ms}ms, store time: {store_ms}ms, " 316 | f"Activate CTAs: {active_cta_ratio*100}%" 317 | ) 318 | return total_time_ms 319 | 320 | 321 | def get_configs_io_bound(): 322 | configs = [] 323 | for num_stages in [2, 3, 4, 5, 6]: 324 | for block_m in [16, 32]: 325 | for block_k in [32, 64]: 326 | for block_n in [32, 64, 128, 256]: 327 | num_warps = 2 if block_n <= 64 else 4 328 | configs.append( 329 | Config( 330 | { 331 | "BLOCK_B": block_m, 332 | "BLOCK_V": block_n, 333 | "BLOCK_D": block_k, 334 | }, 335 | num_stages=num_stages, 336 | num_warps=num_warps, 337 | ) 338 | ) 339 | return configs 340 | 341 | 342 | def get_autotune_config(): 343 | return [ 344 | # basic configs for compute-bound matmuls 345 | Config( 346 | {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, 347 | num_stages=2, 348 | num_warps=4, 349 | ), 350 | Config( 351 | {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 32}, 352 | num_stages=3, 353 | num_warps=8, 354 | ), 355 | Config( 356 | {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 32}, 357 | num_stages=3, 358 | num_warps=8, 359 | ), 360 | Config( 361 | {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 32}, 362 | num_stages=4, 363 | num_warps=4, 364 | ), 365 | Config( 366 | {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 32}, 367 | num_stages=4, 368 | num_warps=4, 369 | ), 370 | Config( 371 | {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, 372 | num_stages=4, 373 | num_warps=4, 374 | ), 375 | Config( 376 | {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, 377 | num_stages=3, 378 | num_warps=8, 379 | ), 380 | Config( 381 | {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, 382 | num_stages=4, 383 | num_warps=8, 384 | ), 385 | Config( 386 | {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 32}, 387 | num_stages=4, 388 | num_warps=4, 389 | ), 390 | Config( 391 | {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 32}, 392 | num_stages=4, 393 | num_warps=4, 394 | ), 395 | Config( 396 | {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 32}, 397 | num_stages=4, 398 | num_warps=4, 399 | ), 400 | Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 32}, num_stages=5, num_warps=2), 401 | # good for int8 402 | Config( 403 | {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, 404 | num_stages=3, 405 | num_warps=8, 406 | ), 407 | Config( 408 | {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, 409 | num_stages=3, 410 | num_warps=16, 411 | ), 412 | Config( 413 | {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, 414 | num_stages=3, 415 | num_warps=8, 416 | ), 417 | Config( 418 | {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, 419 | num_stages=3, 420 | num_warps=16, 421 
| ), 422 | Config( 423 | {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 128}, 424 | num_stages=4, 425 | num_warps=4, 426 | ), 427 | Config( 428 | {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 128}, 429 | num_stages=4, 430 | num_warps=4, 431 | ), 432 | Config( 433 | {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, 434 | num_stages=4, 435 | num_warps=4, 436 | ), 437 | Config( 438 | {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 64}, 439 | num_stages=4, 440 | num_warps=4, 441 | ), 442 | Config( 443 | {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 64}, 444 | num_stages=4, 445 | num_warps=4, 446 | ), 447 | Config( 448 | {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 64}, 449 | num_stages=4, 450 | num_warps=4, 451 | ), 452 | Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 64}, num_stages=5, num_warps=2), 453 | ] + get_configs_io_bound() 454 | 455 | 456 | def _heuristics_from_config(config: Config) -> Callable[..., autotuner.Heuristics]: 457 | return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()}) 458 | 459 | 460 | def _cce_forward_best_config() -> Config: 461 | return Config(dict(BLOCK_B=256, BLOCK_V=128, BLOCK_D=32), num_warps=8, num_stages=3) 462 | 463 | 464 | def cce_forward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: 465 | if _AUTOTUNE: 466 | return _cce_autotune( 467 | configs=get_autotune_config(), 468 | key=["V", "D", "B_BIN"], 469 | prune_configs_by={ 470 | "early_config_prune": early_config_prune, 471 | "perf_model": estimate_matmul_time, 472 | "top_k": 10, 473 | }, 474 | restore_value=["LSE"], 475 | reset_to_zero=["LA"], 476 | ) 477 | else: 478 | return _heuristics_from_config(_cce_forward_best_config()) 479 | 480 | 481 | def _bw_total_ops_fn(B, V, D) -> float: 482 | return 2 * B * V * D + 6 * B * V + 0.2 * (2 * B * V * D + 2 * B * V * D) 483 | 484 | 485 | def _bw_total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): 486 | return 0.2 * (num_cta_v * B * D * dtsize + num_cta_b * D * V * dtsize) 487 | 488 | 489 | def _cce_backward_best_config() -> Config: 490 | return Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4) 491 | 492 | 493 | def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: 494 | if _AUTOTUNE: 495 | return _cce_autotune( 496 | configs=get_autotune_config(), 497 | key=["V", "D", "B_BIN"], 498 | prune_configs_by={ 499 | "early_config_prune": functools.partial( 500 | early_config_prune, shared_memory_factor=2.0 501 | ), 502 | "perf_model": functools.partial( 503 | estimate_matmul_time, 504 | total_ops_fn=_bw_total_ops_fn, 505 | total_store_fn=_bw_total_store_fn, 506 | ), 507 | "top_k": 5, 508 | }, 509 | reset_to_zero=["dE", "dC", "dEC", "dCC", "dBias"], 510 | ) 511 | else: 512 | return _heuristics_from_config(_cce_backward_best_config()) 513 | 514 | 515 | def _indexed_dot_best_config() -> Config: 516 | return Config(dict(BLOCK_B=128, BLOCK_D=256), num_warps=16, num_stages=4) 517 | 518 | 519 | def _indexed_dot_all_configs() -> list[Config]: 520 | return [ 521 | Config( 522 | dict( 523 | BLOCK_B=128, 524 | BLOCK_D=128, 525 | ), 526 | num_warps=4, 527 | num_stages=4, 528 | ), 529 | Config( 530 | dict( 531 | BLOCK_B=128, 532 | BLOCK_D=128, 533 | ), 534 | num_warps=8, 535 | num_stages=4, 536 | ), 537 | Config( 538 | dict( 539 | BLOCK_B=256, 540 | BLOCK_D=256, 541 | ), 542 | num_warps=16, 543 | num_stages=4, 544 | ), 545 | Config( 546 | dict( 547 | BLOCK_B=256, 548 | BLOCK_D=128, 549 | ), 550 | num_warps=16, 551 | num_stages=4, 552 | ), 553 | Config( 554 | dict( 555 | 
BLOCK_B=128, 556 | BLOCK_D=256, 557 | ), 558 | num_warps=16, 559 | num_stages=4, 560 | ), 561 | ] 562 | 563 | 564 | def indexed_dot_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: 565 | if _AUTOTUNE: 566 | return _cce_autotune( 567 | configs=_indexed_dot_all_configs(), 568 | key=["D", "B_BIN"], 569 | reset_to_zero=["Out"], 570 | ) 571 | else: 572 | return _heuristics_from_config(_indexed_dot_best_config()) 573 | -------------------------------------------------------------------------------- /cut_cross_entropy/tl_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import functools 3 | 4 | import triton 5 | import triton.language as tl 6 | from triton.language.extra import libdevice as tl_libdevice 7 | 8 | from cut_cross_entropy.utils import is_package_greater_or_equal 9 | 10 | 11 | @triton.jit 12 | def tl_and_reduce_fn(a, b): 13 | return a & b 14 | 15 | 16 | @triton.jit 17 | def tl_tanh(a: tl.tensor) -> tl.tensor: 18 | return tl_libdevice.tanh(a) 19 | 20 | 21 | @triton.jit 22 | def tl_log1p(a: tl.tensor) -> tl.tensor: 23 | return tl_libdevice.log1p(a) 24 | 25 | 26 | @triton.jit 27 | def tl_softcapping(v: tl.tensor, softcap: float) -> tl.tensor: 28 | return tl_tanh(v / softcap) * softcap 29 | 30 | 31 | @triton.jit 32 | def tl_softcapping_grad(dv: tl.tensor, v: tl.tensor, softcap: float) -> tl.tensor: 33 | v = v / softcap 34 | return dv * (1 - v * v) 35 | 36 | 37 | @triton.jit 38 | def tl_logaddexp(a, b) -> tl.tensor: 39 | minx = tl.minimum(a, b) 40 | mx = tl.maximum(a, b) 41 | return tl_log1p(tl.exp(minx - mx)) + mx 42 | 43 | 44 | @triton.jit 45 | def tl_2sum(a: tl.tensor, b: tl.tensor) -> tuple[tl.tensor, tl.tensor]: 46 | s = a + b 47 | 48 | a_prime = s - b 49 | b_prime = s - a_prime 50 | 51 | delta_a = a - a_prime 52 | delta_b = b - b_prime 53 | 54 | t = delta_a + delta_b 55 | return s, t 56 | 57 | 58 | @triton.jit 59 | def tl_lock_kahan_sum(ptrs, c_ptrs, v, mask, lock_ptr): 60 | while tl.atomic_cas(lock_ptr, 0, 1) == 1: 61 | pass 62 | 63 | s = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") 64 | c = tl.load(c_ptrs, mask=mask, other=0.0, eviction_policy="evict_last") 65 | 66 | s, c = tl_2sum(s, c + v) 67 | 68 | tl.store(ptrs, s, mask=mask, eviction_policy="evict_last") 69 | tl.store(c_ptrs, c, mask=mask, eviction_policy="evict_last") 70 | 71 | tl.debug_barrier() 72 | tl.atomic_xchg(lock_ptr, 0) 73 | 74 | 75 | @triton.jit 76 | def tl_lock_add(ptrs, v, mask, lock_ptr): 77 | while tl.atomic_cas(lock_ptr, 0, 1) == 1: 78 | pass 79 | 80 | cur_v = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") 81 | new_v = v + cur_v 82 | tl.store(ptrs, new_v, mask=mask, eviction_policy="evict_last") 83 | 84 | tl.debug_barrier() 85 | tl.atomic_xchg(lock_ptr, 0) 86 | 87 | 88 | def b_bin_fn(b: int) -> int: 89 | if b >= 1024: 90 | return 1024 91 | elif b <= 128: 92 | return 128 93 | else: 94 | return 512 95 | 96 | 97 | @functools.cache 98 | def is_triton_greater_or_equal_3_2_0() -> bool: 99 | return is_package_greater_or_equal("triton", "3.2.0") 100 | -------------------------------------------------------------------------------- /cut_cross_entropy/torch_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
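# Usage sketch for this torch.compile reference implementation (shapes, device, and dtype
# below are illustrative assumptions, not taken from the sources):
#
#   import torch
#   from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy
#
#   e = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16)  # embeddings (B, T, D)
#   c = torch.randn(32000, 256, device="cuda", dtype=torch.bfloat16)   # classifier weight (V, D)
#   targets = torch.randint(0, 32000, (4, 128), device="cuda")         # labels (B, T)
#   loss = torch_compile_linear_cross_entropy(e, c, targets, shift=1, reduction="mean")
#
# Unlike the Triton "cce" kernels, this path materializes the full (N, V) logits matrix.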
2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from cut_cross_entropy.constants import IGNORE_INDEX 6 | from cut_cross_entropy.doc import LINEAR_CROSS_ENTROPY_DOC, add_doc_start 7 | from cut_cross_entropy.utils import ( 8 | _build_flat_valids, 9 | handle_reduction_none, 10 | softcapping, 11 | ) 12 | from cut_cross_entropy.vocab_parallel import ( 13 | VocabParallelOptions, 14 | vocab_parallel_torch_compile_lce_apply, 15 | ) 16 | 17 | 18 | @torch.compile(fullgraph=True) 19 | def torch_compile_linear_cross_entropy_apply( 20 | e: torch.Tensor, 21 | c: torch.Tensor, 22 | targets: torch.Tensor, 23 | bias: torch.Tensor | None = None, 24 | softcap: float | None = None, 25 | *, 26 | ignore_index: int = IGNORE_INDEX, 27 | reduction: str = "mean", 28 | ) -> torch.Tensor: 29 | logits = e @ c.T 30 | 31 | if bias is not None: 32 | logits = logits + bias 33 | 34 | if softcap is not None: 35 | logits = softcapping(logits, softcap) 36 | 37 | loss = F.cross_entropy(logits.float(), targets, ignore_index=ignore_index, reduction=reduction) 38 | 39 | return loss 40 | 41 | 42 | @add_doc_start(LINEAR_CROSS_ENTROPY_DOC) 43 | def torch_compile_linear_cross_entropy( 44 | e: torch.Tensor, 45 | c: torch.Tensor, 46 | targets: torch.Tensor, 47 | bias: torch.Tensor | None = None, 48 | ignore_index: int = IGNORE_INDEX, 49 | softcap: float | None = None, 50 | reduction: str = "mean", 51 | shift: bool | int = 0, 52 | vocab_parallel_options: VocabParallelOptions | None = None, 53 | ) -> torch.Tensor: 54 | assert e.size()[0:-1] == targets.size() 55 | assert e.size(-1) == c.size(1) 56 | 57 | orig_b_size = targets.size() 58 | e = e.contiguous() 59 | targets = targets.contiguous() 60 | 61 | shift = int(shift) 62 | valids = _build_flat_valids(targets, ignore_index, shift) 63 | 64 | e = e.flatten(0, -2) 65 | targets = targets.flatten() 66 | 67 | if valids is not None: 68 | e = e[valids] 69 | targets = targets[(valids + shift) if shift != 0 else valids] 70 | 71 | if vocab_parallel_options is None: 72 | loss = torch_compile_linear_cross_entropy_apply( 73 | e, 74 | c, 75 | targets, 76 | bias, 77 | softcap, 78 | ignore_index=ignore_index, 79 | reduction=reduction, 80 | ) 81 | else: 82 | loss = vocab_parallel_torch_compile_lce_apply( 83 | vocab_parallel_options, e, c, targets, bias, softcap, reduction 84 | ) 85 | 86 | if reduction == "none": 87 | loss = handle_reduction_none(orig_b_size, valids, shift, loss) 88 | 89 | if shift != 0: 90 | loss = loss[..., shift:] 91 | 92 | return loss 93 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from cut_cross_entropy.transformers.patch import cce_patch 3 | 4 | __all__ = ["cce_patch"] 5 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/gemma2.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
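# Each transformers integration module below follows the same pattern: cce_forward mirrors
# the upstream forward, but when labels are provided and PatchOptions.use_lce() allows it,
# the loss is computed directly from the final hidden states and lm_head.weight via
# apply_lce, so the (batch, seq, vocab) logits tensor is never materialized; otherwise it
# falls back to the stock logits + loss_function path. patch_gemma2 either rebinds forward
# on a single model instance or overrides Gemma2ForCausalLM.forward globally.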
2 | from types import MethodType 3 | from typing import Optional, Tuple, Union 4 | 5 | import torch 6 | import transformers 7 | from transformers.cache_utils import HybridCache 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from transformers.models.gemma2.modeling_gemma2 import ( 10 | _CONFIG_FOR_DOC, 11 | GEMMA2_INPUTS_DOCSTRING, 12 | logger, 13 | ) 14 | from transformers.utils import ( 15 | add_start_docstrings_to_model_forward, 16 | replace_return_docstrings, 17 | ) 18 | 19 | from .utils import PatchOptions, TransformersModelT, apply_lce 20 | 21 | _PATCH_OPTS: PatchOptions | None = None 22 | 23 | 24 | @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) 25 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 26 | def cce_forward( 27 | self, 28 | input_ids: torch.LongTensor = None, 29 | attention_mask: Optional[torch.Tensor] = None, 30 | position_ids: Optional[torch.LongTensor] = None, 31 | past_key_values: Optional[HybridCache] = None, 32 | inputs_embeds: Optional[torch.FloatTensor] = None, 33 | labels: Optional[torch.LongTensor] = None, 34 | use_cache: Optional[bool] = None, 35 | output_attentions: Optional[bool] = None, 36 | output_hidden_states: Optional[bool] = None, 37 | return_dict: Optional[bool] = None, 38 | cache_position: Optional[torch.LongTensor] = None, 39 | num_logits_to_keep: int = 0, 40 | **loss_kwargs, 41 | ) -> Union[Tuple, CausalLMOutputWithPast]: 42 | r""" 43 | Args: 44 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 45 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 46 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 47 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 48 | 49 | num_logits_to_keep (`int`, *optional*): 50 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 51 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 52 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 53 | 54 | Returns: 55 | 56 | Example: 57 | 58 | ```python 59 | >>> from transformers import AutoTokenizer, GemmaForCausalLM 60 | 61 | >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") 62 | >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") 63 | 64 | >>> prompt = "What is your favorite condiment?" 65 | >>> inputs = tokenizer(prompt, return_tensors="pt") 66 | 67 | >>> # Generate 68 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 69 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 70 | "What is your favorite condiment?" 71 | ```""" 72 | 73 | if self.training and self.config._attn_implementation != "eager": 74 | logger.warning_once( 75 | "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " 76 | f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." 
77 | ) 78 | output_attentions = ( 79 | output_attentions if output_attentions is not None else self.config.output_attentions 80 | ) 81 | output_hidden_states = ( 82 | output_hidden_states 83 | if output_hidden_states is not None 84 | else self.config.output_hidden_states 85 | ) 86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 87 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 88 | outputs = self.model( 89 | input_ids=input_ids, 90 | attention_mask=attention_mask, 91 | position_ids=position_ids, 92 | past_key_values=past_key_values, 93 | inputs_embeds=inputs_embeds, 94 | use_cache=use_cache, 95 | output_attentions=output_attentions, 96 | output_hidden_states=output_hidden_states, 97 | return_dict=return_dict, 98 | cache_position=cache_position, 99 | **loss_kwargs, 100 | ) 101 | 102 | hidden_states = outputs[0] 103 | loss = None 104 | logits = None 105 | 106 | if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): 107 | assert labels is not None 108 | loss = apply_lce(hidden_states, self.lm_head.weight, labels, _PATCH_OPTS, **loss_kwargs) 109 | else: 110 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 111 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 112 | if self.config.final_logit_softcapping is not None: 113 | logits = logits / self.config.final_logit_softcapping 114 | logits = torch.tanh(logits) 115 | logits = logits * self.config.final_logit_softcapping 116 | 117 | if labels is not None: 118 | loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) 119 | 120 | if not return_dict: 121 | output = (logits,) + outputs[1:] 122 | return (loss,) + output if loss is not None else output 123 | 124 | return CausalLMOutputWithPast( 125 | loss=loss, 126 | logits=logits, 127 | past_key_values=outputs.past_key_values, 128 | hidden_states=outputs.hidden_states, 129 | attentions=outputs.attentions, 130 | ) 131 | 132 | 133 | def patch_gemma2( 134 | maybe_model: TransformersModelT | str | transformers.PretrainedConfig, 135 | patch_options: PatchOptions, 136 | ) -> TransformersModelT | None: 137 | global _PATCH_OPTS 138 | from transformers.models.gemma2 import modeling_gemma2 139 | 140 | _PATCH_OPTS = patch_options 141 | 142 | if isinstance(maybe_model, transformers.PreTrainedModel): 143 | assert isinstance( 144 | maybe_model, modeling_gemma2.Gemma2ForCausalLM 145 | ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." 146 | maybe_model.forward = MethodType(cce_forward, maybe_model) 147 | return maybe_model 148 | else: 149 | modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward 150 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/llama.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | from types import MethodType 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import transformers 8 | from transformers.cache_utils import Cache 9 | from transformers.modeling_outputs import CausalLMOutputWithPast 10 | from transformers.models.llama.modeling_llama import ( 11 | _CONFIG_FOR_DOC, 12 | LLAMA_INPUTS_DOCSTRING, 13 | KwargsForCausalLM, 14 | Unpack, 15 | ) 16 | from transformers.utils import ( 17 | add_start_docstrings_to_model_forward, 18 | replace_return_docstrings, 19 | ) 20 | 21 | from .utils import PatchOptions, TransformersModelT, apply_lce 22 | 23 | _PATCH_OPTS: PatchOptions | None = None 24 | 25 | 26 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 27 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 28 | def cce_forward( 29 | self, 30 | input_ids: torch.LongTensor = None, 31 | attention_mask: Optional[torch.Tensor] = None, 32 | position_ids: Optional[torch.LongTensor] = None, 33 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 34 | inputs_embeds: Optional[torch.FloatTensor] = None, 35 | labels: Optional[torch.LongTensor] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | cache_position: Optional[torch.LongTensor] = None, 41 | num_logits_to_keep: int = 0, 42 | **kwargs: Unpack[KwargsForCausalLM], 43 | ) -> Union[Tuple, CausalLMOutputWithPast]: 44 | r""" 45 | Args: 46 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 47 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 48 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 49 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 50 | 51 | num_logits_to_keep (`int`, *optional*): 52 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 53 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 54 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 55 | 56 | Returns: 57 | 58 | Example: 59 | 60 | ```python 61 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 62 | 63 | >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") 64 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") 65 | 66 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 67 | >>> inputs = tokenizer(prompt, return_tensors="pt") 68 | 69 | >>> # Generate 70 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 71 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 72 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
73 | ```""" 74 | output_attentions = ( 75 | output_attentions if output_attentions is not None else self.config.output_attentions 76 | ) 77 | output_hidden_states = ( 78 | output_hidden_states 79 | if output_hidden_states is not None 80 | else self.config.output_hidden_states 81 | ) 82 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 83 | 84 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 85 | outputs = self.model( 86 | input_ids=input_ids, 87 | attention_mask=attention_mask, 88 | position_ids=position_ids, 89 | past_key_values=past_key_values, 90 | inputs_embeds=inputs_embeds, 91 | use_cache=use_cache, 92 | output_attentions=output_attentions, 93 | output_hidden_states=output_hidden_states, 94 | return_dict=return_dict, 95 | cache_position=cache_position, 96 | **kwargs, 97 | ) 98 | 99 | hidden_states = outputs[0] 100 | loss = None 101 | logits = None 102 | 103 | if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): 104 | assert labels is not None 105 | loss = apply_lce(hidden_states, self.lm_head.weight, labels, _PATCH_OPTS, **kwargs) 106 | else: 107 | if self.config.pretraining_tp > 1: 108 | lm_head_slices = self.lm_head.weight.split( 109 | self.vocab_size // self.config.pretraining_tp, dim=0 110 | ) 111 | logits = [ 112 | F.linear(hidden_states, lm_head_slices[i]) 113 | for i in range(self.config.pretraining_tp) 114 | ] 115 | logits = torch.cat(logits, dim=-1) 116 | else: 117 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 118 | 119 | if labels is not None: 120 | loss = self.loss_function( 121 | logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs 122 | ) 123 | 124 | if not return_dict: 125 | output = (logits,) + outputs[1:] 126 | return (loss,) + output if loss is not None else output 127 | 128 | return CausalLMOutputWithPast( 129 | loss=loss, 130 | logits=logits, 131 | past_key_values=outputs.past_key_values, 132 | hidden_states=outputs.hidden_states, 133 | attentions=outputs.attentions, 134 | ) 135 | 136 | 137 | def patch_llama( 138 | maybe_model: TransformersModelT | str | transformers.PretrainedConfig, 139 | patch_options: PatchOptions, 140 | ) -> TransformersModelT | None: 141 | global _PATCH_OPTS 142 | from transformers.models.llama import modeling_llama 143 | 144 | _PATCH_OPTS = patch_options 145 | 146 | if isinstance(maybe_model, transformers.PreTrainedModel): 147 | assert isinstance( 148 | maybe_model, modeling_llama.LlamaForCausalLM 149 | ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." 150 | maybe_model.forward = MethodType(cce_forward, maybe_model) 151 | return maybe_model 152 | else: 153 | modeling_llama.LlamaForCausalLM.forward = cce_forward 154 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | from types import MethodType 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.distributed 7 | import transformers 8 | from transformers.cache_utils import Cache 9 | from transformers.modeling_outputs import CausalLMOutputWithPast 10 | from transformers.models.mistral.modeling_mistral import ( 11 | _CONFIG_FOR_DOC, 12 | MISTRAL_INPUTS_DOCSTRING, 13 | KwargsForCausalLM, 14 | Unpack, 15 | ) 16 | from transformers.utils import ( 17 | add_start_docstrings_to_model_forward, 18 | replace_return_docstrings, 19 | ) 20 | 21 | from .utils import PatchOptions, TransformersModelT, apply_lce 22 | 23 | _PATCH_OPTS: PatchOptions | None = None 24 | 25 | 26 | @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) 27 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 28 | def cce_forward( 29 | self, 30 | input_ids: torch.LongTensor = None, 31 | attention_mask: Optional[torch.Tensor] = None, 32 | position_ids: Optional[torch.LongTensor] = None, 33 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 34 | inputs_embeds: Optional[torch.FloatTensor] = None, 35 | labels: Optional[torch.LongTensor] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | cache_position: Optional[torch.LongTensor] = None, 41 | num_logits_to_keep: int = 0, 42 | **kwargs: Unpack[KwargsForCausalLM], 43 | ) -> Union[Tuple, CausalLMOutputWithPast]: 44 | r""" 45 | Args: 46 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 47 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 48 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 49 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 50 | 51 | num_logits_to_keep (`int`, *optional*): 52 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 53 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 54 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 55 | 56 | Returns: 57 | 58 | Example: 59 | 60 | ```python 61 | >>> from transformers import AutoTokenizer, MistralForCausalLM 62 | 63 | >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") 64 | >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") 65 | 66 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 67 | >>> inputs = tokenizer(prompt, return_tensors="pt") 68 | 69 | >>> # Generate 70 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 71 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 72 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
73 | ```""" 74 | 75 | output_attentions = ( 76 | output_attentions if output_attentions is not None else self.config.output_attentions 77 | ) 78 | output_hidden_states = ( 79 | output_hidden_states 80 | if output_hidden_states is not None 81 | else self.config.output_hidden_states 82 | ) 83 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 84 | 85 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 86 | outputs = self.model( 87 | input_ids=input_ids, 88 | attention_mask=attention_mask, 89 | position_ids=position_ids, 90 | past_key_values=past_key_values, 91 | inputs_embeds=inputs_embeds, 92 | use_cache=use_cache, 93 | output_attentions=output_attentions, 94 | output_hidden_states=output_hidden_states, 95 | return_dict=return_dict, 96 | cache_position=cache_position, 97 | **kwargs, 98 | ) 99 | 100 | hidden_states = outputs[0] 101 | loss = None 102 | logits = None 103 | 104 | if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): 105 | assert labels is not None 106 | loss = apply_lce(hidden_states, self.lm_head.weight, labels, _PATCH_OPTS, **kwargs) 107 | else: 108 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 109 | 110 | if labels is not None: 111 | loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) 112 | 113 | if not return_dict: 114 | output = (logits,) + outputs[1:] 115 | return (loss,) + output if loss is not None else output 116 | 117 | return CausalLMOutputWithPast( 118 | loss=loss, 119 | logits=logits, 120 | past_key_values=outputs.past_key_values, 121 | hidden_states=outputs.hidden_states, 122 | attentions=outputs.attentions, 123 | ) 124 | 125 | 126 | def patch_mistral( 127 | maybe_model: TransformersModelT | str | transformers.PretrainedConfig, 128 | patch_options: PatchOptions, 129 | ) -> TransformersModelT | None: 130 | global _PATCH_OPTS 131 | from transformers.models.mistral import modeling_mistral 132 | 133 | _PATCH_OPTS = patch_options 134 | 135 | if isinstance(maybe_model, transformers.PreTrainedModel): 136 | assert isinstance( 137 | maybe_model, modeling_mistral.MistralForCausalLM 138 | ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}." 139 | maybe_model.forward = MethodType(cce_forward, maybe_model) 140 | return maybe_model 141 | else: 142 | modeling_mistral.MistralForCausalLM.forward = cce_forward 143 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/patch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
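# Usage sketch (model name and options are illustrative; any model type in PATCH_FNS works):
#
#   import transformers
#   from cut_cross_entropy.transformers import cce_patch
#
#   model = transformers.AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#   model = cce_patch(model, impl="cce", reduction="mean")  # patch this instance
#
#   cce_patch("llama")  # or: patch LlamaForCausalLM.forward globally before loading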
2 | from typing import overload 3 | 4 | import transformers 5 | 6 | from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl 7 | from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT 8 | 9 | from .gemma2 import patch_gemma2 10 | from .llama import patch_llama 11 | from .mistral import patch_mistral 12 | from .phi3 import patch_phi3 13 | from .qwen2 import patch_qwen2 14 | from .utils import PatchOptions, TransformersModelT 15 | 16 | PATCH_FNS = { 17 | "llama": patch_llama, 18 | "phi3": patch_phi3, 19 | "gemma2": patch_gemma2, 20 | "mistral": patch_mistral, 21 | "qwen2": patch_qwen2, 22 | } 23 | 24 | 25 | @overload 26 | def cce_patch( 27 | model_type_or_model: str | transformers.PretrainedConfig, 28 | impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, 29 | reduction: str = "mean", 30 | filter_eps: float | str | None = "auto", 31 | accum_e_fp32: bool = False, 32 | accum_c_fp32: bool = False, 33 | filter_e_grad: bool = True, 34 | filter_c_grad: bool = True, 35 | train_only: bool = False, 36 | ) -> None: ... 37 | 38 | 39 | @overload 40 | def cce_patch( 41 | model_type_or_model: TransformersModelT, 42 | impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, 43 | reduction: str = "mean", 44 | filter_eps: float | str | None = "auto", 45 | accum_e_fp32: bool = False, 46 | accum_c_fp32: bool = False, 47 | filter_e_grad: bool = True, 48 | filter_c_grad: bool = True, 49 | train_only: bool = False, 50 | ) -> TransformersModelT: ... 51 | 52 | 53 | def cce_patch( 54 | model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig, 55 | impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, 56 | reduction: str = "mean", 57 | filter_eps: float | str | None = "auto", 58 | accum_e_fp32: bool = False, 59 | accum_c_fp32: bool = False, 60 | filter_e_grad: bool = True, 61 | filter_c_grad: bool = True, 62 | train_only: bool = False, 63 | ) -> TransformersModelT | None: 64 | if isinstance(impl, LinearCrossEntropyImpl): 65 | impl = impl.name.lower() 66 | 67 | if impl not in (v.name.lower() for v in LinearCrossEntropyImpl): 68 | raise ValueError(f"Unknown {impl=}") 69 | 70 | if isinstance(model_type_or_model, transformers.PreTrainedModel): 71 | model_type = model_type_or_model.config.model_type 72 | elif isinstance(model_type_or_model, transformers.PretrainedConfig): 73 | model_type = model_type_or_model.model_type 74 | else: 75 | model_type = model_type_or_model 76 | 77 | patch_options = PatchOptions( 78 | impl=impl, 79 | reduction=reduction, 80 | filter_eps=filter_eps, 81 | accum_e_fp32=accum_e_fp32, 82 | accum_c_fp32=accum_c_fp32, 83 | filter_e_grad=filter_e_grad, 84 | filter_c_grad=filter_c_grad, 85 | train_only=train_only, 86 | ) 87 | 88 | if model_type in PATCH_FNS: 89 | return PATCH_FNS[model_type](model_type_or_model, patch_options) 90 | else: 91 | raise RuntimeError(f"Unknown model type {model_type}") 92 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/phi3.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | from types import MethodType 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import transformers 7 | from transformers.cache_utils import Cache 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from transformers.models.phi3.modeling_phi3 import ( 10 | _CONFIG_FOR_DOC, 11 | PHI3_INPUTS_DOCSTRING, 12 | KwargsForCausalLM, 13 | Unpack, 14 | ) 15 | from transformers.utils import ( 16 | add_start_docstrings_to_model_forward, 17 | replace_return_docstrings, 18 | ) 19 | 20 | from .utils import PatchOptions, TransformersModelT, apply_lce 21 | 22 | _PATCH_OPTS: PatchOptions | None = None 23 | 24 | 25 | @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) 26 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 27 | def cce_forward( 28 | self, 29 | input_ids: torch.LongTensor = None, 30 | attention_mask: Optional[torch.Tensor] = None, 31 | position_ids: Optional[torch.LongTensor] = None, 32 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 33 | inputs_embeds: Optional[torch.FloatTensor] = None, 34 | labels: Optional[torch.LongTensor] = None, 35 | use_cache: Optional[bool] = None, 36 | output_attentions: Optional[bool] = None, 37 | output_hidden_states: Optional[bool] = None, 38 | return_dict: Optional[bool] = None, 39 | cache_position: Optional[torch.LongTensor] = None, 40 | num_logits_to_keep: int = 0, 41 | **kwargs: Unpack[KwargsForCausalLM], 42 | ) -> Union[Tuple, CausalLMOutputWithPast]: 43 | r""" 44 | Args: 45 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 46 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 47 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 48 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 49 | 50 | num_logits_to_keep (`int`, *optional*): 51 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 52 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 53 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 54 | 55 | Returns: 56 | 57 | Example: 58 | 59 | ```python 60 | >>> from transformers import AutoTokenizer, Phi3ForCausalLM 61 | 62 | >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") 63 | >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") 64 | 65 | >>> prompt = "This is an example script ." 66 | >>> inputs = tokenizer(prompt, return_tensors="pt") 67 | 68 | >>> # Generate 69 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 70 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 71 | 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' 72 | ```""" 73 | output_attentions = ( 74 | output_attentions if output_attentions is not None else self.config.output_attentions 75 | ) 76 | output_hidden_states = ( 77 | output_hidden_states 78 | if output_hidden_states is not None 79 | else self.config.output_hidden_states 80 | ) 81 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 82 | 83 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 84 | outputs = self.model( 85 | input_ids=input_ids, 86 | attention_mask=attention_mask, 87 | position_ids=position_ids, 88 | past_key_values=past_key_values, 89 | inputs_embeds=inputs_embeds, 90 | use_cache=use_cache, 91 | output_attentions=output_attentions, 92 | output_hidden_states=output_hidden_states, 93 | return_dict=return_dict, 94 | cache_position=cache_position, 95 | **kwargs, 96 | ) 97 | 98 | hidden_states = outputs[0] 99 | loss = None 100 | logits = None 101 | 102 | if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): 103 | assert labels is not None 104 | loss = apply_lce(hidden_states, self.lm_head.weight, labels, _PATCH_OPTS, **kwargs) 105 | else: 106 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 107 | 108 | if labels is not None: 109 | loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) 110 | 111 | if not return_dict: 112 | output = (logits,) + outputs[1:] 113 | return (loss,) + output if loss is not None else output 114 | 115 | return CausalLMOutputWithPast( 116 | loss=loss, 117 | logits=logits, 118 | past_key_values=outputs.past_key_values, 119 | hidden_states=outputs.hidden_states, 120 | attentions=outputs.attentions, 121 | ) 122 | 123 | 124 | def patch_phi3( 125 | maybe_model: TransformersModelT | str | transformers.PretrainedConfig, 126 | patch_options: PatchOptions, 127 | ) -> TransformersModelT | None: 128 | global _PATCH_OPTS 129 | from transformers.models.phi3 import modeling_phi3 130 | 131 | _PATCH_OPTS = patch_options 132 | 133 | if isinstance(maybe_model, transformers.PreTrainedModel): 134 | assert isinstance( 135 | maybe_model, modeling_phi3.Phi3ForCausalLM 136 | ), f"Expected a Phi3ForCausalLM model. Got {type(maybe_model)}." 137 | maybe_model.forward = MethodType(cce_forward, maybe_model) 138 | return maybe_model 139 | else: 140 | modeling_phi3.Phi3ForCausalLM.forward = cce_forward 141 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/qwen2.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | from types import MethodType 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | import transformers 7 | from transformers.cache_utils import Cache 8 | from transformers.modeling_outputs import CausalLMOutputWithPast 9 | from transformers.models.qwen2.modeling_qwen2 import ( 10 | _CONFIG_FOR_DOC, 11 | QWEN2_INPUTS_DOCSTRING, 12 | KwargsForCausalLM, 13 | Unpack, 14 | ) 15 | from transformers.utils import ( 16 | add_start_docstrings_to_model_forward, 17 | replace_return_docstrings, 18 | ) 19 | 20 | from .utils import PatchOptions, TransformersModelT, apply_lce 21 | 22 | _PATCH_OPTS: PatchOptions | None = None 23 | 24 | 25 | @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) 26 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 27 | def cce_forward( 28 | self, 29 | input_ids: torch.LongTensor = None, 30 | attention_mask: Optional[torch.Tensor] = None, 31 | position_ids: Optional[torch.LongTensor] = None, 32 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 33 | inputs_embeds: Optional[torch.FloatTensor] = None, 34 | labels: Optional[torch.LongTensor] = None, 35 | use_cache: Optional[bool] = None, 36 | output_attentions: Optional[bool] = None, 37 | output_hidden_states: Optional[bool] = None, 38 | return_dict: Optional[bool] = None, 39 | cache_position: Optional[torch.LongTensor] = None, 40 | num_logits_to_keep: int = 0, 41 | **kwargs: Unpack[KwargsForCausalLM], 42 | ) -> Union[Tuple, CausalLMOutputWithPast]: 43 | r""" 44 | Args: 45 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 46 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 47 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 48 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 49 | 50 | num_logits_to_keep (`int`, *optional*): 51 | Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all 52 | `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that 53 | token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 54 | 55 | Returns: 56 | 57 | Example: 58 | 59 | ```python 60 | >>> from transformers import AutoTokenizer, Qwen2ForCausalLM 61 | 62 | >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 63 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 64 | 65 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 66 | >>> inputs = tokenizer(prompt, return_tensors="pt") 67 | 68 | >>> # Generate 69 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 70 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 71 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
72 | ```""" 73 | 74 | output_attentions = ( 75 | output_attentions if output_attentions is not None else self.config.output_attentions 76 | ) 77 | output_hidden_states = ( 78 | output_hidden_states 79 | if output_hidden_states is not None 80 | else self.config.output_hidden_states 81 | ) 82 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 83 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 84 | outputs = self.model( 85 | input_ids=input_ids, 86 | attention_mask=attention_mask, 87 | position_ids=position_ids, 88 | past_key_values=past_key_values, 89 | inputs_embeds=inputs_embeds, 90 | use_cache=use_cache, 91 | output_attentions=output_attentions, 92 | output_hidden_states=output_hidden_states, 93 | return_dict=return_dict, 94 | cache_position=cache_position, 95 | **kwargs, 96 | ) 97 | 98 | hidden_states = outputs[0] 99 | loss = None 100 | logits = None 101 | 102 | if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): 103 | assert labels is not None 104 | loss = apply_lce(hidden_states, self.lm_head.weight, labels, _PATCH_OPTS, **kwargs) 105 | else: 106 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 107 | if labels is not None: 108 | loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) 109 | 110 | if not return_dict: 111 | output = (logits,) + outputs[1:] 112 | return (loss,) + output if loss is not None else output 113 | 114 | return CausalLMOutputWithPast( 115 | loss=loss, 116 | logits=logits, 117 | past_key_values=outputs.past_key_values, 118 | hidden_states=outputs.hidden_states, 119 | attentions=outputs.attentions, 120 | ) 121 | 122 | 123 | def patch_qwen2( 124 | maybe_model: TransformersModelT | str | transformers.PretrainedConfig, 125 | patch_options: PatchOptions, 126 | ) -> TransformersModelT | None: 127 | global _PATCH_OPTS 128 | from transformers.models.qwen2 import modeling_qwen2 129 | 130 | _PATCH_OPTS = patch_options 131 | 132 | if isinstance(maybe_model, transformers.PreTrainedModel): 133 | assert isinstance( 134 | maybe_model, modeling_qwen2.Qwen2ForCausalLM 135 | ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." 136 | maybe_model.forward = MethodType(cce_forward, maybe_model) 137 | return maybe_model 138 | else: 139 | modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward 140 | -------------------------------------------------------------------------------- /cut_cross_entropy/transformers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
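# Note on gradient-accumulation normalization: when the trainer passes num_items_in_batch,
# apply_lce below switches a "mean" reduction to "sum" and divides the summed loss by
# num_items_in_batch itself, matching how recent transformers loss functions normalize
# across accumulation steps. Labels are always shifted by one (shift=True) inside apply_lce.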
2 | from dataclasses import dataclass 3 | from typing import TypeVar 4 | 5 | import torch 6 | import transformers 7 | 8 | from cut_cross_entropy import linear_cross_entropy 9 | from cut_cross_entropy.cce_utils import CCEPreset 10 | 11 | TransformersModelT = TypeVar("TransformersModelT", bound=transformers.PreTrainedModel) 12 | 13 | 14 | class CCEKwargs(CCEPreset): 15 | impl: str 16 | reduction: str 17 | 18 | 19 | @dataclass 20 | class PatchOptions: 21 | impl: str 22 | reduction: str 23 | filter_eps: float | str | None 24 | accum_e_fp32: bool 25 | accum_c_fp32: bool 26 | filter_e_grad: bool 27 | filter_c_grad: bool 28 | train_only: bool 29 | 30 | def to_kwargs(self) -> CCEKwargs: 31 | return CCEKwargs( 32 | impl=self.impl, 33 | reduction=self.reduction, 34 | filter_eps=self.filter_eps, 35 | accum_e_fp32=self.accum_e_fp32, 36 | accum_c_fp32=self.accum_c_fp32, 37 | filter_e_grad=self.filter_e_grad, 38 | filter_c_grad=self.filter_c_grad, 39 | ) 40 | 41 | def use_lce(self, labels: torch.Tensor | None, training: bool) -> bool: 42 | if labels is None: 43 | return False 44 | 45 | if not training and self.train_only: 46 | return False 47 | 48 | return True 49 | 50 | 51 | def apply_lce( 52 | e: torch.Tensor, 53 | c: torch.Tensor, 54 | labels: torch.Tensor, 55 | opts: PatchOptions, 56 | bias: torch.Tensor | None = None, 57 | **loss_kwargs, 58 | ) -> torch.Tensor: 59 | num_items_in_batch = loss_kwargs.get("num_items_in_batch", None) 60 | cce_kwargs = opts.to_kwargs() 61 | if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean": 62 | cce_kwargs["reduction"] = "sum" 63 | else: 64 | num_items_in_batch = None 65 | 66 | loss = linear_cross_entropy( 67 | e, 68 | c, 69 | labels.to(e.device), 70 | bias=bias, 71 | shift=True, 72 | **cce_kwargs, 73 | ) 74 | 75 | if num_items_in_batch is not None: 76 | loss = loss / num_items_in_batch 77 | 78 | return loss 79 | -------------------------------------------------------------------------------- /cut_cross_entropy/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
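# softcapping below maps logits z to softcap * tanh(z / softcap): approximately the identity
# for |z| much smaller than softcap, and bounded to (-softcap, softcap) otherwise.
# Illustrative values (approximate):
#
#   import torch
#   from cut_cross_entropy.utils import softcapping
#
#   softcapping(torch.tensor([0.5, 30.0, 300.0]), 30.0)  # ~[0.5, 22.8, 30.0]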
2 | import functools 3 | import importlib.metadata 4 | 5 | import packaging.version 6 | import torch 7 | 8 | 9 | @torch.compile(fullgraph=True) 10 | def softcapping(logits: torch.Tensor, softcap: float) -> torch.Tensor: 11 | return torch.tanh(logits / softcap) * softcap 12 | 13 | 14 | def _handle_eps(filter_eps: float | str | None, dtype: torch.dtype) -> float | None: 15 | if filter_eps is None: 16 | return None 17 | elif isinstance(filter_eps, float): 18 | return filter_eps 19 | elif filter_eps == "auto": 20 | return torch.finfo(dtype).eps / 32 21 | else: 22 | raise RuntimeError(f"Unknown eps {filter_eps=}") 23 | 24 | 25 | def _build_flat_valids( 26 | targets: torch.Tensor, 27 | ignore_index: int, 28 | shift: int, 29 | ) -> torch.Tensor | None: 30 | if shift != 0: 31 | targets = targets[..., shift:] 32 | else: 33 | targets = targets.flatten() 34 | 35 | valids = (targets != ignore_index).nonzero().to(torch.int32) 36 | 37 | if shift == 0: 38 | assert valids.size(1) == 1 39 | return valids.squeeze(1) if valids.numel() != targets.numel() else None 40 | 41 | for i in range(targets.ndim - 1): 42 | valids[:, i] *= targets.stride(i) 43 | 44 | assert targets.stride(-1) == 1 45 | 46 | return valids.sum(1) 47 | 48 | 49 | def handle_reduction_none( 50 | batch_shape: torch.Size, valids: torch.Tensor | None, shift: int, loss: torch.Tensor 51 | ) -> torch.Tensor: 52 | if valids is None: 53 | return loss.view(batch_shape) 54 | 55 | full_loss = loss.new_zeros((batch_shape.numel(),)) 56 | full_loss[(valids + shift) if shift != 0 else valids] = loss 57 | 58 | return full_loss.view(batch_shape) 59 | 60 | 61 | @functools.cache 62 | def is_package_greater_or_equal(package: str, version: str) -> bool: 63 | return packaging.version.parse(importlib.metadata.version(package)) >= packaging.version.parse( 64 | version 65 | ) 66 | 67 | 68 | @functools.cache 69 | def is_torch_greater_or_equal_2_5() -> bool: 70 | return is_package_greater_or_equal("torch", "2.5") 71 | -------------------------------------------------------------------------------- /cut_cross_entropy/vocab_parallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from cut_cross_entropy.vocab_parallel.utils import VocabParallelOptions 3 | from cut_cross_entropy.vocab_parallel.vocab_parallel_torch_compile import ( 4 | vocab_parallel_torch_compile_lce_apply, 5 | ) 6 | 7 | __all__ = ["VocabParallelOptions", "vocab_parallel_torch_compile_lce_apply"] 8 | -------------------------------------------------------------------------------- /cut_cross_entropy/vocab_parallel/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | import torch.distributed 6 | from typing_extensions import Self 7 | 8 | 9 | def partition_n_into_range(n: int, rank: int, world_size: int) -> tuple[int, int]: 10 | start = rank * (n // world_size) + min(rank, n % world_size) 11 | stop = start + n // world_size + (1 if rank < (n % world_size) else 0) 12 | 13 | return start, stop 14 | 15 | 16 | @dataclass 17 | class VocabParallelOptions: 18 | """Options to configure vocab parallel loss computation 19 | 20 | :param start: The start index for this rank's range in the vocab. 21 | :param end: The ending index (non-inclusive) 22 | :param group: The distributed process group defining the world for this vocab parallel rank. 
23 | :param reduce_e_grad: Whether or not to all_reduce/synchronize the gradient for the embedding 24 | matrix across all ranks. This typically should be true, but some frameworks may require setting this to false. 25 | """ 26 | 27 | start: int 28 | stop: int 29 | group: torch.distributed.ProcessGroup | None = None 30 | reduce_e_grad: bool = True 31 | 32 | @classmethod 33 | def from_vocab( 34 | cls, 35 | vocab_size: int, 36 | group: torch.distributed.ProcessGroup | None = None, 37 | reduce_e_grad: bool = True, 38 | ) -> Self: 39 | rank = torch.distributed.get_rank(group) 40 | world_size = torch.distributed.get_world_size(group) 41 | 42 | start, stop = partition_n_into_range(vocab_size, rank, world_size) 43 | 44 | return cls(start, stop, group, reduce_e_grad) 45 | 46 | 47 | @torch.compile(fullgraph=True) 48 | def vp_reduce_lse(vp_lse: torch.Tensor, pg: torch.distributed.ProcessGroup | None) -> torch.Tensor: 49 | lse_max = vp_lse.clone() 50 | torch.distributed.all_reduce(lse_max, op=torch.distributed.ReduceOp.MAX, group=pg) 51 | 52 | lse = (vp_lse - lse_max).exp() 53 | torch.distributed.all_reduce(lse, group=pg) 54 | return lse_max + lse.log() 55 | 56 | 57 | @torch.compile(fullgraph=True) 58 | def vp_reduce_correct_logit( 59 | vp_correct_logit: torch.Tensor, 60 | pg: torch.distributed.ProcessGroup | None, 61 | dtype: torch.dtype | None = None, 62 | ) -> torch.Tensor: 63 | correct_logit = vp_correct_logit.to(dtype=dtype, copy=True) 64 | torch.distributed.all_reduce(correct_logit, group=pg) 65 | return correct_logit 66 | 67 | 68 | @torch.compile(fullgraph=True) 69 | def vp_reduce_e_grad( 70 | e_grad: torch.Tensor, 71 | pg: torch.distributed.ProcessGroup | None, 72 | ) -> torch.Tensor: 73 | reduced_grad = e_grad.to(dtype=torch.float32, copy=True) 74 | torch.distributed.all_reduce(reduced_grad, group=pg) 75 | 76 | return reduced_grad.type_as(e_grad) 77 | 78 | 79 | class _VocabParallelReduceEGradHook(torch.autograd.Function): 80 | @staticmethod 81 | def forward( 82 | ctx, e: torch.Tensor, pg: torch.distributed.ProcessGroup | None = None 83 | ) -> torch.Tensor: 84 | ctx.pg = pg 85 | return e 86 | 87 | @staticmethod 88 | def backward(ctx, e_grad: torch.Tensor) -> tuple[torch.Tensor, None]: 89 | return vp_reduce_e_grad(e_grad, ctx.pg), None 90 | 91 | 92 | def vp_reduce_e_grad_hook( 93 | e: torch.Tensor, vocab_parallel_options: VocabParallelOptions 94 | ) -> torch.Tensor: 95 | if vocab_parallel_options.reduce_e_grad: 96 | return _VocabParallelReduceEGradHook.apply(e, vocab_parallel_options.group) 97 | else: 98 | return e 99 | -------------------------------------------------------------------------------- /cut_cross_entropy/vocab_parallel/vocab_parallel_torch_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
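# How the vocab-parallel loss is assembled: each rank holds the classifier rows for its
# vocab shard [start, stop). It computes a shard-local logsumexp over its logits (vp_lse)
# and the correct-class logit when the target lands in its shard (vp_correct_logit, zero
# otherwise). The full per-token loss is then
#
#   loss = logsumexp_full - correct_logit,
#   logsumexp_full = m + log(sum_r exp(vp_lse_r - m)),  with m = max_r vp_lse_r,
#
# where the reductions over ranks r are the all_reduce calls in vocab_parallel/utils.py
# (vp_reduce_lse, vp_reduce_correct_logit).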
2 | import torch 3 | import torch.distributed 4 | 5 | from cut_cross_entropy.utils import softcapping 6 | from cut_cross_entropy.vocab_parallel.utils import ( 7 | VocabParallelOptions, 8 | vp_reduce_correct_logit, 9 | vp_reduce_e_grad_hook, 10 | vp_reduce_lse, 11 | ) 12 | 13 | 14 | class _VocabParallelLossFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward( 17 | ctx, 18 | vp_correct_logit: torch.Tensor, 19 | vp_lse: torch.Tensor, 20 | pg: torch.distributed.ProcessGroup | None, 21 | ) -> torch.Tensor: 22 | lse = vp_reduce_lse(vp_lse, pg) 23 | correct_logit = vp_reduce_correct_logit(vp_correct_logit, pg, dtype=lse.dtype) 24 | 25 | ctx.save_for_backward(vp_lse, lse) 26 | 27 | return lse - correct_logit 28 | 29 | @staticmethod 30 | def backward( 31 | ctx, grad_loss: torch.Tensor 32 | ) -> tuple[torch.Tensor | None, torch.Tensor | None, None]: 33 | grad_correct_logit = -grad_loss 34 | 35 | vp_lse, lse = ctx.saved_tensors 36 | 37 | grad_lse = (vp_lse - lse).exp() * grad_loss 38 | 39 | return grad_correct_logit, grad_lse, None 40 | 41 | 42 | def _vp_loss_fn( 43 | vp_correct_logit: torch.Tensor, 44 | vp_lse: torch.Tensor, 45 | pg: torch.distributed.ProcessGroup | None, 46 | ) -> torch.Tensor: 47 | return _VocabParallelLossFunction.apply(vp_correct_logit, vp_lse, pg) 48 | 49 | 50 | def _vp_torch_compile_correct_logit_lse( 51 | e: torch.Tensor, 52 | vocab_parallel_c: torch.Tensor, 53 | targets: torch.Tensor, 54 | start: int, 55 | stop: int, 56 | vocab_parallel_bias: torch.Tensor | None = None, 57 | softcap: float | None = None, 58 | ) -> tuple[torch.Tensor, torch.Tensor]: 59 | vp_logits = e @ vocab_parallel_c.T 60 | 61 | if vocab_parallel_bias is not None: 62 | vp_logits = vp_logits + vocab_parallel_bias 63 | 64 | if softcap is not None: 65 | vp_logits = softcapping(vp_logits, softcap) 66 | 67 | vp_lse = torch.logsumexp(vp_logits.float(), -1) 68 | 69 | is_target_in_range = (targets < stop) & (targets >= start) 70 | masked_targets = torch.where(is_target_in_range, targets - start, targets.new_zeros(())) 71 | 72 | arange_indexer = torch.arange(0, len(vp_lse), device=targets.device, dtype=targets.dtype) 73 | vp_correct_logit = torch.where( 74 | is_target_in_range, vp_logits[arange_indexer, masked_targets], vp_logits.new_zeros(()) 75 | ) 76 | 77 | return vp_correct_logit, vp_lse 78 | 79 | 80 | @torch.compile(fullgraph=True) 81 | def vocab_parallel_torch_compile_lce_apply( 82 | vocab_parallel_options: VocabParallelOptions, 83 | e: torch.Tensor, 84 | vocab_parallel_c: torch.Tensor, 85 | targets: torch.Tensor, 86 | vocab_parallel_bias: torch.Tensor | None, 87 | softcap: float | None, 88 | reduction: str, 89 | ) -> torch.Tensor: 90 | pg = vocab_parallel_options.group 91 | 92 | e = vp_reduce_e_grad_hook(e, vocab_parallel_options) 93 | 94 | vp_correct_logit, vp_lse = _vp_torch_compile_correct_logit_lse( 95 | e, 96 | vocab_parallel_c, 97 | targets, 98 | vocab_parallel_options.start, 99 | vocab_parallel_options.stop, 100 | vocab_parallel_bias=vocab_parallel_bias, 101 | softcap=softcap, 102 | ) 103 | 104 | loss = _vp_loss_fn(vp_correct_logit, vp_lse, pg) 105 | 106 | if reduction == "none": 107 | pass 108 | elif reduction == "mean": 109 | loss = loss.mean() 110 | elif reduction == "sum": 111 | loss = loss.sum() 112 | else: 113 | raise ValueError(f"Unknown reduction {reduction!r}") 114 | 115 | return loss 116 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 
[project] 2 | name = "cut-cross-entropy" 3 | dynamic = ["version"] 4 | readme = "README.md" 5 | description = "Code for cut cross entropy, a memory efficient implementation of linear-cross-entropy loss." 6 | requires-python = ">= 3.9" 7 | 8 | dependencies = [ 9 | "torch>=2.4", 10 | 'triton>=3.0.0 ; platform_system != "Darwin"', 11 | ] 12 | 13 | [project.optional-dependencies] 14 | test = [ 15 | "pytest", 16 | "pytest-sugar", 17 | "setuptools>=77.0.3", 18 | ] 19 | transformers = [ 20 | "transformers>=4.48.2", 21 | ] 22 | 23 | all = [ 24 | "cut-cross-entropy[transformers]", 25 | 'deepspeed>=0.15.1 ; platform_system != "Darwin"', 26 | "accelerate>=0.34.2", 27 | "datasets>=3.1.0", 28 | "huggingface_hub>=0.26.2", 29 | "pandas", 30 | "fire", 31 | "tqdm", 32 | ] 33 | 34 | dev = [ 35 | "cut-cross-entropy[all,test]", 36 | "build", 37 | "twine", 38 | "pre-commit", 39 | "pytest-xdist", 40 | ] 41 | 42 | 43 | [build-system] 44 | requires = ["setuptools >= 61.0", "setuptools-scm"] 45 | build-backend = "setuptools.build_meta" 46 | 47 | [tool.setuptools.packages.find] 48 | include = ["cut_cross_entropy*"] 49 | 50 | [tool.setuptools.dynamic] 51 | version = {attr = "cut_cross_entropy.__version__"} 52 | 53 | [tool.pytest.ini_options] 54 | minversion = "6.0" 55 | addopts = "--verbose -rsxX -q" 56 | testpaths = ["tests"] 57 | 58 | [tool.ruff] 59 | line-length = 100 60 | lint.select = ["F", "I"] 61 | 62 | [tool.ruff.lint.isort] 63 | known-first-party = ["cut_cross_entropy"] 64 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | name=phi3.5 6 | 7 | 8 | torchrun --standalone --nproc-per-node=8 --module training.train \ 9 | --deepspeed training/zero3.json \ 10 | --model_name $name \ 11 | --output_dir checkpoints \ 12 | --per_device_train_batch_size 2 \ 13 | --gradient_accumulation_steps 4 \ 14 | --per_device_eval_batch_size 8 \ 15 | --cross_entropy_impl cce \ 16 | --eval_strategy "no" \ 17 | --eval_steps 1000 \ 18 | --learning_rate 2e-5 \ 19 | --dataloader_num_workers 4 \ 20 | --run_name $name \ 21 | --report_to 'none' 22 | -------------------------------------------------------------------------------- /tests/test_cce_indexed_dot.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
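# Test methodology used throughout these tests: a ground-truth value is computed in
# float32 ("gt"), a reference is computed in the same dtype as the kernel under test
# ("ref"), and the kernel's absolute error w.r.t. gt must not exceed the reference's
# error by more than a small tolerance. This checks that the fused kernels are no less
# accurate than a straightforward PyTorch implementation at the same precision.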
2 | import pytest 3 | import torch 4 | 5 | from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel 6 | from cut_cross_entropy.utils import softcapping 7 | 8 | skip_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") 9 | 10 | 11 | @skip_no_cuda 12 | @pytest.mark.parametrize( 13 | "dtype,error_tol", [(torch.float32, 1e-6), (torch.float16, 1e-3), (torch.bfloat16, 1e-2)] 14 | ) 15 | @pytest.mark.parametrize("softcap", [None, 20.0]) 16 | @pytest.mark.parametrize("has_bias", [True, False]) 17 | @pytest.mark.parametrize("shape", [(256, 512, 128), (255, 507, 128), (255, 507, 123)]) 18 | def test_indexed_dot( 19 | dtype: torch.dtype, 20 | error_tol: float, 21 | softcap: float | None, 22 | has_bias: bool, 23 | shape: tuple[int, int, int], 24 | ): 25 | torch.cuda.manual_seed(0) 26 | 27 | if dtype == torch.bfloat16 and not torch.cuda.is_available(): 28 | pytest.skip(reason="BF16 not avaliable") 29 | 30 | N, V, D = shape 31 | e = torch.randn((N, D), device="cuda", dtype=dtype) / (D**0.5) 32 | c = torch.randn((V, D), device="cuda", dtype=dtype) 33 | 34 | c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] 35 | 36 | if has_bias: 37 | bias = torch.randn(V, device="cuda", dtype=dtype) * 0.02 38 | else: 39 | bias = None 40 | 41 | inds = torch.randint(0, V, size=(N,), device="cuda") 42 | 43 | gt = -(e.float() * c[inds].float()).sum(-1) 44 | if bias is not None: 45 | gt -= bias[inds].float() 46 | 47 | if softcap is not None: 48 | gt = softcapping(gt, softcap) 49 | 50 | ref = -(e * c[inds]).sum(-1, dtype=torch.float32) 51 | if bias is not None: 52 | ref -= bias[inds].float() 53 | 54 | if softcap is not None: 55 | ref = softcapping(ref, softcap) 56 | 57 | ref = ref.to(dtype=dtype) 58 | 59 | cce_neg_dot = indexed_neg_dot_forward_kernel(e, c, inds, bias=bias, softcap=softcap) 60 | 61 | expected_error = (gt - ref.float()).abs() 62 | cce_error = (gt - cce_neg_dot.float()).abs() 63 | 64 | assert ( 65 | cce_error <= (expected_error + error_tol) 66 | ).all(), f"{(cce_error - expected_error).relu().max()=}" 67 | -------------------------------------------------------------------------------- /tests/test_cce_loss_backward.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | import pytest 3 | import torch 4 | 5 | from cut_cross_entropy import linear_cross_entropy 6 | from cut_cross_entropy.constants import IGNORE_INDEX 7 | from cut_cross_entropy.utils import softcapping 8 | 9 | skip_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") 10 | 11 | 12 | def _grads( 13 | e: torch.Tensor, 14 | c: torch.Tensor, 15 | targets: torch.Tensor, 16 | bias: torch.Tensor | None, 17 | softcap: float | None, 18 | shift: bool, 19 | reduction: str, 20 | fp32: bool = False, 21 | ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 22 | orig_e, orig_c, orig_bias = e, c, bias 23 | if bias is not None: 24 | bias.grad = None 25 | e.grad = c.grad = None 26 | 27 | N, T = targets.size() 28 | if shift: 29 | e = e[:, :-1] 30 | targets = targets[:, 1:] 31 | T = T - 1 32 | 33 | e = e.flatten(0, -2) 34 | targets = targets.flatten() 35 | 36 | if fp32: 37 | e = e.float() 38 | c = c.float() 39 | bias = bias.float() if bias is not None else None 40 | 41 | logits = e @ c.T 42 | if bias is not None: 43 | logits += bias 44 | 45 | if softcap is not None: 46 | logits = softcapping(logits, softcap) 47 | 48 | loss = torch.nn.functional.cross_entropy( 49 | logits.float(), targets, ignore_index=IGNORE_INDEX, reduction=reduction 50 | ) 51 | 52 | if reduction == "sum": 53 | loss = loss / (targets != IGNORE_INDEX).count_nonzero() 54 | 55 | loss.mean().backward() 56 | 57 | assert orig_e.grad is not None 58 | assert orig_c.grad is not None 59 | 60 | if bias is not None: 61 | assert orig_bias is not None 62 | assert orig_bias.grad is not None 63 | return ( 64 | orig_e.grad.detach().clone(), 65 | orig_c.grad.detach().clone(), 66 | orig_bias.grad.detach().clone(), 67 | ) 68 | else: 69 | return orig_e.grad.detach().clone(), orig_c.grad.detach().clone() 70 | 71 | 72 | @skip_no_cuda 73 | @pytest.mark.parametrize("impl", ["cce", "torch_compile", "cce_exact"]) 74 | @pytest.mark.parametrize("dtype,error_tol", [(torch.float16, 1e-3), (torch.bfloat16, 1e-2)]) 75 | @pytest.mark.parametrize("softcap", [None, 20.0]) 76 | @pytest.mark.parametrize("has_bias", [False, True]) 77 | @pytest.mark.parametrize("shift", [False, True]) 78 | @pytest.mark.parametrize("invalids", [False, True]) 79 | @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) 80 | @pytest.mark.parametrize("shape", [(256, 512, 128), (252, 507, 128), (252, 507, 123)]) 81 | def test_loss_backward( 82 | impl: str, 83 | dtype: torch.dtype, 84 | error_tol: float, 85 | softcap: float | None, 86 | has_bias: bool, 87 | shift: bool, 88 | invalids: bool, 89 | reduction: str, 90 | shape: tuple[int, int, int], 91 | ): 92 | torch.set_float32_matmul_precision("highest") 93 | torch._dynamo.config.cache_size_limit = 256 94 | torch.cuda.manual_seed(0) 95 | 96 | if dtype == torch.bfloat16 and not torch.cuda.is_available(): 97 | pytest.skip(reason="BF16 not avaliable") 98 | 99 | N, V, D = shape 100 | e = torch.randn((N, D), device="cuda", dtype=dtype, requires_grad=False) / (D**0.5) 101 | c = torch.randn((V, D), device="cuda", dtype=dtype, requires_grad=False) 102 | 103 | c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] 104 | 105 | targets = torch.randint(0, V, size=(N,), device="cuda") 106 | 107 | if invalids: 108 | inds = torch.randperm(len(targets), device="cuda")[0 : int(0.2 * len(targets))] 109 | targets[inds] = IGNORE_INDEX 110 | 111 | e = e.view(4, -1, D) 112 | 113 | targets = targets.view(e.size()[0:-1]) 114 | 115 | if has_bias: 116 | bias = torch.randn(V, device="cuda", dtype=dtype) * 0.02 
117 | bias.requires_grad_(True) 118 | else: 119 | bias = None 120 | 121 | e.requires_grad_(True) 122 | c.requires_grad_(True) 123 | 124 | gt = _grads(e, c, targets, bias, softcap, shift, reduction, fp32=True) 125 | 126 | ref = _grads(e, c, targets, bias, softcap, shift, reduction) 127 | 128 | e.grad = c.grad = None 129 | if bias is not None: 130 | bias.grad = None 131 | 132 | loss = linear_cross_entropy( 133 | e, c, targets, bias=bias, softcap=softcap, shift=shift, reduction=reduction, impl=impl 134 | ) 135 | if reduction == "sum": 136 | loss = loss / (targets != IGNORE_INDEX).count_nonzero() 137 | loss.mean().backward() 138 | assert e.grad is not None 139 | assert c.grad is not None 140 | 141 | if bias is not None: 142 | assert bias.grad is not None 143 | cce = (e.grad, c.grad, bias.grad) 144 | else: 145 | cce = (e.grad, c.grad) 146 | 147 | expected_error = tuple((vgt - vref).abs().flatten() for vgt, vref in zip(gt, ref, strict=True)) 148 | cce_error = tuple((vgt - vcce).abs().flatten() for vgt, vcce in zip(gt, cce, strict=True)) 149 | 150 | for i in range(len(expected_error)): 151 | if not (cce_error[i] <= (expected_error[i] + error_tol)).all(): 152 | errors = (cce_error[i] - expected_error[i]).relu() 153 | argmax_error = int(errors.argmax()) 154 | raise ValueError( 155 | f"{i=}, {errors.max()=}, {cce[i].flatten()[argmax_error]=}, " 156 | f"{gt[i].flatten()[argmax_error]=}, {ref[i].flatten()[argmax_error]=}" 157 | ) 158 | 159 | 160 | @skip_no_cuda 161 | @pytest.mark.parametrize( 162 | "compute_de,compute_dc,compute_dbias", 163 | [(True, False, True), (False, True, False), (False, False, True)], 164 | ) 165 | def test_loss_partials(compute_de: bool, compute_dc: bool, compute_dbias: bool): 166 | torch.cuda.manual_seed(0) 167 | dtype = torch.bfloat16 168 | 169 | N, V, D = (256, 512, 128) 170 | e = torch.randn((N, D), device="cuda", dtype=dtype, requires_grad=False) / (D**0.5) 171 | c = torch.randn((V, D), device="cuda", dtype=dtype, requires_grad=False) 172 | bias = torch.randn(V, device="cuda", dtype=dtype, requires_grad=False) * 0.01 173 | 174 | c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] 175 | 176 | targets = torch.randint(0, V, size=(N,), device="cuda") 177 | 178 | e = e.view(4, -1, D) 179 | targets = targets.view(e.size()[0:-1]) 180 | 181 | e.requires_grad_(compute_de) 182 | c.requires_grad_(compute_dc) 183 | bias.requires_grad_(compute_dbias) 184 | 185 | e.grad = c.grad = bias.grad = None 186 | loss = linear_cross_entropy(e, c, targets, bias=bias, reduction="mean") 187 | loss.backward() 188 | 189 | assert (e.grad is not None) == compute_de 190 | assert (c.grad is not None) == compute_dc 191 | assert (bias.grad is not None) == compute_dbias 192 | -------------------------------------------------------------------------------- /tests/test_cce_loss_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 
2 | import pytest 3 | import torch 4 | 5 | from cut_cross_entropy import linear_cross_entropy 6 | from cut_cross_entropy.constants import IGNORE_INDEX 7 | from cut_cross_entropy.utils import softcapping 8 | 9 | skip_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") 10 | 11 | 12 | def _loss( 13 | e: torch.Tensor, 14 | c: torch.Tensor, 15 | targets: torch.Tensor, 16 | bias: torch.Tensor | None, 17 | softcap: float | None, 18 | shift: int, 19 | ) -> torch.Tensor: 20 | N, T = targets.size() 21 | 22 | if shift != 0: 23 | e = e[:, :-shift] 24 | targets = targets[:, shift:] 25 | T = T - shift 26 | 27 | e = e.flatten(0, -2) 28 | targets = targets.flatten() 29 | 30 | logits = e @ c.T 31 | if bias is not None: 32 | logits += bias 33 | 34 | if softcap is not None: 35 | logits = softcapping(logits, softcap) 36 | 37 | loss = torch.nn.functional.cross_entropy( 38 | logits.float(), targets, ignore_index=IGNORE_INDEX, reduction="none" 39 | ) 40 | 41 | return loss.view(N, T) 42 | 43 | 44 | @skip_no_cuda 45 | @pytest.mark.parametrize("impl", ["cce", "torch_compile"]) 46 | @pytest.mark.parametrize( 47 | "dtype,error_tol", [(torch.float32, 1e-5), (torch.float16, 1e-3), (torch.bfloat16, 1e-2)] 48 | ) 49 | @pytest.mark.parametrize("softcap", [None, 20.0]) 50 | @pytest.mark.parametrize("has_bias", [True, False]) 51 | @pytest.mark.parametrize("shift", [0, 2]) 52 | @pytest.mark.parametrize("invalids", [False, True]) 53 | @pytest.mark.parametrize("shape", [(256, 512, 128), (252, 507, 128), (252, 507, 123)]) 54 | def test_loss_forward( 55 | impl: str, 56 | dtype: torch.dtype, 57 | error_tol: float, 58 | softcap: float | None, 59 | has_bias: bool, 60 | shift: int, 61 | invalids: bool, 62 | shape: tuple[int, int, int], 63 | ): 64 | torch.set_float32_matmul_precision("highest") 65 | torch._dynamo.config.cache_size_limit = 256 66 | torch.cuda.manual_seed(0) 67 | 68 | if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): 69 | pytest.skip(reason="BF16 not available") 70 | 71 | N, V, D = shape 72 | e = torch.randn((N, D), device="cuda", dtype=dtype) / (D**0.5) 73 | c = torch.randn((V, D), device="cuda", dtype=dtype) 74 | 75 | c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] 76 | 77 | if has_bias: 78 | bias = torch.randn(V, device="cuda", dtype=dtype) * 0.01 79 | else: 80 | bias = None 81 | 82 | e = e.view(4, -1, D) 83 | 84 | targets = torch.randint(0, V, size=(N,), device="cuda") 85 | 86 | if invalids: 87 | inds = torch.randperm(len(targets), device="cuda")[0 : int(0.2 * len(targets))] 88 | targets[inds] = IGNORE_INDEX 89 | 90 | targets = targets.view(e.size()[0:-1]) 91 | 92 | gt = _loss( 93 | e.float(), c.float(), targets, bias.float() if bias is not None else None, softcap, shift 94 | ) 95 | 96 | torch.set_float32_matmul_precision("highest" if dtype == torch.float32 else "high") 97 | ref = _loss(e, c, targets, bias, softcap, shift) 98 | 99 | cce_loss = linear_cross_entropy( 100 | e, c, targets, bias=bias, softcap=softcap, shift=shift, reduction="none", impl=impl 101 | ) 102 | 103 | expected_error = (gt - ref).abs() 104 | cce_error = (gt - cce_loss).abs() 105 | 106 | assert ( 107 | cce_error <= (expected_error + error_tol) 108 | ).all(), f"{(cce_error - expected_error).relu().max()=}" 109 | -------------------------------------------------------------------------------- /tests/test_cce_lse.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2 | import pytest 3 | import torch 4 | 5 | from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel 6 | from cut_cross_entropy.utils import softcapping 7 | 8 | skip_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA") 9 | 10 | 11 | def _lse( 12 | e: torch.Tensor, c: torch.Tensor, bias: torch.Tensor | None, softcap: float | None 13 | ) -> torch.Tensor: 14 | logits = e @ c.T 15 | if bias is not None: 16 | logits += bias 17 | 18 | if softcap is not None: 19 | logits = softcapping(logits, softcap) 20 | return torch.logsumexp(logits.float(), dim=-1) 21 | 22 | 23 | @skip_no_cuda 24 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) 25 | @pytest.mark.parametrize("softcap", [None, 20.0]) 26 | @pytest.mark.parametrize("has_bias", [True, False]) 27 | @pytest.mark.parametrize("shape", [(256, 512, 128), (255, 507, 128), (255, 507, 123)]) 28 | def test_lse( 29 | dtype: torch.dtype, 30 | softcap: float | None, 31 | has_bias: bool, 32 | shape: tuple[int, int, int], 33 | ): 34 | torch.set_float32_matmul_precision("highest") 35 | torch.cuda.manual_seed(0) 36 | 37 | if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): 38 | pytest.skip(reason="BF16 not available") 39 | 40 | N, V, D = shape 41 | e = torch.randn((N, D), device="cuda", dtype=dtype) / (D**0.5) 42 | c = torch.randn((V, D), device="cuda", dtype=dtype) 43 | 44 | c[0 : min(N, V) // 2] = e[0 : min(N, V) // 2] 45 | 46 | if has_bias: 47 | bias = torch.randn(V, device="cuda", dtype=dtype) * 0.02 48 | else: 49 | bias = None 50 | 51 | gt = _lse(e.float(), c.float(), bias.float() if bias is not None else None, softcap) 52 | 53 | torch.set_float32_matmul_precision("highest" if dtype == torch.float32 else "high") 54 | ref = _lse(e, c, bias, softcap) 55 | 56 | cce_lse = cce_lse_forward_kernel(e, c, bias, softcap=softcap) 57 | 58 | expected_error = (gt - ref).abs() 59 | cce_error = (gt - cce_lse).abs() 60 | 61 | assert ( 62 | cce_error <= (expected_error + 1e-5) 63 | ).all(), f"{(cce_error - expected_error).relu().max()=}" 64 | -------------------------------------------------------------------------------- /tests/test_vocab_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import contextlib 3 | import socket 4 | 5 | import pytest 6 | import torch 7 | import torch.distributed 8 | from torch.multiprocessing.spawn import spawn as mp_spawn 9 | 10 | from cut_cross_entropy import VocabParallelOptions, linear_cross_entropy 11 | from cut_cross_entropy.constants import IGNORE_INDEX 12 | from cut_cross_entropy.vocab_parallel.utils import partition_n_into_range 13 | 14 | 15 | def find_free_port() -> int: 16 | """ 17 | Returns a free port on the system.
18 | """ 19 | with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: 20 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 21 | sock.bind(("localhost", 0)) 22 | _, port = sock.getsockname() 23 | return port 24 | 25 | 26 | def _target_fn_test_vp( 27 | rank: int, 28 | world_size: int, 29 | port: int, 30 | impl: str, 31 | dtype: torch.dtype, 32 | error_tol: float, 33 | invalids: bool, 34 | ): 35 | device = ( 36 | torch.device("cpu") 37 | if not torch.cuda.is_available() 38 | else torch.device("cuda", rank % torch.cuda.device_count()) 39 | ) 40 | 41 | if device.type == "cuda": 42 | torch.cuda.set_device(device) 43 | backend = "cpu:gloo,cuda:nccl" 44 | else: 45 | backend = "gloo" 46 | 47 | store = torch.distributed.TCPStore( 48 | "localhost", port, world_size=world_size, is_master=rank == 0 49 | ) 50 | 51 | torch.distributed.init_process_group( 52 | backend=backend, store=store, world_size=world_size, rank=rank 53 | ) 54 | 55 | N, V, D = (252, 507, 123) 56 | 57 | e = torch.randn((N, D), device=device, dtype=dtype) / (D**0.5) 58 | c = torch.randn((V, D), device=device, dtype=dtype) 59 | 60 | targets = torch.randint(0, V, size=(N,), device=device) 61 | if invalids: 62 | inds = torch.randperm(len(targets), device=device)[0 : int(0.2 * len(targets))] 63 | targets[inds] = IGNORE_INDEX 64 | 65 | e = e.view(4, -1, D) 66 | targets = targets.view(e.size()[0:-1]) 67 | 68 | torch.distributed.broadcast(e, src=0) 69 | torch.distributed.broadcast(c, src=0) 70 | torch.distributed.broadcast(targets, src=0) 71 | 72 | vocab_parallel_options = VocabParallelOptions.from_vocab(V) 73 | 74 | vp_c = c[vocab_parallel_options.start : vocab_parallel_options.stop].clone() 75 | 76 | vp_c.requires_grad_(True) 77 | e.requires_grad_(True) 78 | vp_loss = linear_cross_entropy( 79 | e, vp_c, targets, impl=impl, vocab_parallel_options=vocab_parallel_options 80 | ) 81 | vp_loss.backward() 82 | 83 | assert e.grad is not None 84 | vp_e_grad = e.grad.clone() 85 | e.grad = None 86 | 87 | c.requires_grad_(True) 88 | loss = linear_cross_entropy(e, c, targets, impl=impl) 89 | loss.backward() 90 | 91 | assert c.grad is not None 92 | assert vp_c.grad is not None 93 | assert torch.allclose( 94 | c.grad[vocab_parallel_options.start : vocab_parallel_options.stop], 95 | vp_c.grad, 96 | atol=error_tol, 97 | ), "c grad not close" 98 | 99 | assert e.grad is not None 100 | assert torch.allclose( 101 | e.grad, 102 | vp_e_grad, 103 | atol=error_tol, 104 | ), f"{(e.grad - vp_e_grad).abs().max().item()=}" 105 | 106 | 107 | @pytest.mark.parametrize("impl", ["torch_compile", "cce_exact"]) 108 | @pytest.mark.parametrize("dtype,error_tol", [(torch.float16, 1e-3), (torch.bfloat16, 1e-2)]) 109 | @pytest.mark.parametrize("nprocs", [4]) 110 | @pytest.mark.parametrize("invalids", [False, True]) 111 | def test_vocab_parallel( 112 | impl: str, dtype: torch.dtype, error_tol: float, nprocs: int, invalids: bool 113 | ): 114 | if impl == "cce" and not torch.cuda.is_available(): 115 | pytest.skip("Testing vocab parallel CCE requires cuda") 116 | 117 | mp_spawn( 118 | _target_fn_test_vp, 119 | args=(nprocs, find_free_port(), impl, dtype, error_tol, invalids), 120 | nprocs=nprocs, 121 | join=True, 122 | ) 123 | 124 | 125 | @pytest.mark.parametrize("n", [1023, 2048]) 126 | @pytest.mark.parametrize("world_size", [7, 8]) 127 | def test_partition_n_into_range(n: int, world_size: int): 128 | start = 0 129 | for rank in range(world_size): 130 | end = start + n // world_size + (1 if rank < (n % world_size) else 0) 131 | 132 | assert 
partition_n_into_range(n, rank, world_size) == (start, end) 133 | 134 | start = end 135 | 136 | assert end == n 137 | assert partition_n_into_range(n, world_size - 1, world_size)[1] == n 138 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | import os 3 | import subprocess 4 | import time 5 | from dataclasses import dataclass, field 6 | from pathlib import Path 7 | from typing import Any, Sequence, cast 8 | 9 | import datasets 10 | import torch 11 | import torch.distributed 12 | import transformers 13 | from torch.utils.data import Dataset 14 | from transformers.trainer import EvalPrediction 15 | 16 | from cut_cross_entropy.transformers import cce_patch 17 | 18 | IGNORE_INDEX = -100 19 | SYSTEM_PROMPT = "You are a helpful AI assistant." 20 | PROMPT_DICT = { 21 | "prompt_input": ( 22 | "Below is an instruction that describes a task, paired with an input that provides further context. " 23 | "Write a response that appropriately completes the request.\n\n" 24 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n" 25 | ), 26 | "prompt_no_input": ( 27 | "Below is an instruction that describes a task. " 28 | "Write a response that appropriately completes the request.\n\n" 29 | "### Instruction:\n{instruction}\n\n" 30 | ), 31 | } 32 | 33 | MODEL_NAME_MAP = { 34 | "gemma2": "google/gemma-2-2b-it", 35 | "phi3.5": "microsoft/Phi-3.5-mini-instruct", 36 | "llama3": "meta-llama/Meta-Llama-3-8B-Instruct", 37 | "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407", 38 | "qwen2.5": "Qwen/Qwen2.5-7B-Instruct", 39 | } 40 | 41 | DATA_NAME_MAP = {"alpaca": "yahma/alpaca-cleaned", "open-webtext": "Skylion007/openwebtext"} 42 | 43 | 44 | @dataclass 45 | class ModelArguments: 46 | model_name: str 47 | attn_impl: str | None = None 48 | cross_entropy_impl: str = "cce" 49 | 50 | 51 | @dataclass 52 | class DataArguments: 53 | dataset_name: str = "alpaca" 54 | sequence_length: int = 512 55 | 56 | 57 | @dataclass 58 | class TrainingArguments(transformers.TrainingArguments): 59 | remove_unused_columns: bool = False 60 | torch_compile: bool = False 61 | fp16: bool = False 62 | bf16: bool = True 63 | tf32: bool = True 64 | gradient_checkpoint: bool = True 65 | logging_strategy: str = "steps" 66 | logging_steps: int = 1 67 | warmup_ratio: float = 0.05 68 | dataloader_num_workers: int = 12 69 | dataloader_pin_memory: bool = True 70 | save_strategy: str = "no" 71 | save_steps: int = 400 72 | save_total_limit: int = 3 73 | num_train_epochs: float = 1.0 74 | gradient_checkpoint_kwargs: dict[str, Any] = field( 75 | default_factory=lambda: dict(use_reentrant=True) 76 | ) 77 | 78 | 79 | def download_hf(name: str, repo_type: str = "model"): 80 | if not Path(name).exists(): 81 | for i in range(10): 82 | try: 83 | subprocess.check_call( 84 | [ 85 | "huggingface-cli", 86 | "download", 87 | "--exclude=original/*", 88 | f"--repo-type={repo_type}", 89 | name, 90 | ] 91 | ) 92 | except Exception as e: 93 | if i == 9: 94 | raise e 95 | else: 96 | break 97 | 98 | time.sleep(1) 99 | 100 | 101 | def preprocess( 102 | source: str, 103 | target: str, 104 | tokenizer: transformers.PreTrainedTokenizer, 105 | uses_system_prompt: bool = True, 106 | ) -> dict: 107 | """Preprocess the data by tokenizing.""" 108 | if uses_system_prompt: 109 | messages = [ 110 | {"role": "system", "content": SYSTEM_PROMPT}, 111 | ] 112 | else: 113 | messages = [] 114 
| 115 | messages.extend( 116 | ( 117 | {"role": "user", "content": source}, 118 | {"role": "assistant", "content": target}, 119 | ) 120 | ) 121 | tokenization = tokenizer.apply_chat_template( 122 | messages, 123 | add_generation_prompt=False, 124 | return_dict=True, 125 | ) 126 | input_ids = torch.as_tensor(tokenization["input_ids"]) 127 | 128 | target_ids = tokenizer.encode(target, add_special_tokens=False, return_tensors="pt")[0] 129 | 130 | labels = input_ids.clone() 131 | for offset in reversed(range(0, len(input_ids) - len(target_ids))): 132 | if (labels[offset : offset + len(target_ids)] == target_ids).all(): 133 | labels[0:offset] = IGNORE_INDEX 134 | break 135 | 136 | return dict(input_ids=input_ids, labels=labels) 137 | 138 | 139 | class SupervisedDataset(Dataset): 140 | """Dataset for supervised fine-tuning.""" 141 | 142 | def __init__( 143 | self, 144 | data_args: DataArguments, 145 | seed: int, 146 | tokenizer: transformers.PreTrainedTokenizer, 147 | split: str = "train", 148 | uses_system_prompt: bool = True, 149 | ): 150 | super().__init__() 151 | self.dataset = datasets.load_dataset(data_args.dataset_name, split="train") 152 | self.tokenizer = tokenizer 153 | self.uses_system_prompt = uses_system_prompt 154 | 155 | def __len__(self): 156 | return len(self.dataset) 157 | 158 | def __getitem__(self, i) -> dict[str, torch.Tensor]: 159 | element = self.dataset[i] 160 | if element["input"] == "": 161 | prompt_template = PROMPT_DICT["prompt_no_input"] 162 | else: 163 | prompt_template = PROMPT_DICT["prompt_input"] 164 | 165 | source = prompt_template.format_map(element) 166 | target = element["output"] 167 | 168 | return preprocess(source, target, self.tokenizer, self.uses_system_prompt) 169 | 170 | 171 | @dataclass 172 | class DataCollatorForSupervisedDataset: 173 | """Collate examples for supervised fine-tuning.""" 174 | 175 | pad_token_id: int | None 176 | padding_side: str 177 | 178 | def __call__(self, instances: Sequence[dict]) -> dict[str, torch.Tensor | None]: 179 | input_ids, labels = tuple( 180 | [instance[key] for instance in instances] for key in ("input_ids", "labels") 181 | ) 182 | max_len = max(len(v) for v in input_ids) 183 | assert self.pad_token_id is not None 184 | padded_input_ids = torch.full((len(input_ids), max_len), self.pad_token_id) 185 | padded_labels = torch.full((len(input_ids), max_len), IGNORE_INDEX) 186 | position_ids = torch.zeros((len(input_ids), max_len), dtype=torch.int64) 187 | attention_mask = torch.full((len(input_ids), max_len), False, dtype=torch.bool) 188 | 189 | for i, (inp, lbl) in enumerate(zip(input_ids, labels, strict=True)): 190 | if self.padding_side == "right": 191 | slc = slice(len(inp)) 192 | else: 193 | slc = slice(-len(inp), None) 194 | 195 | padded_input_ids[i, slc] = inp 196 | padded_labels[i, slc] = lbl 197 | position_ids[i, slc] = torch.arange(len(inp), dtype=position_ids.dtype) 198 | attention_mask[i, slc] = True 199 | 200 | return dict( 201 | input_ids=padded_input_ids, 202 | labels=padded_labels, 203 | attention_mask=attention_mask, 204 | position_ids=position_ids, 205 | ) 206 | 207 | 208 | class PretrainDataset(Dataset): 209 | """Dataset for pretraining.""" 210 | 211 | def __init__( 212 | self, 213 | data_args: DataArguments, 214 | seed: int, 215 | tokenizer: transformers.PreTrainedTokenizer, 216 | split: str = "train", 217 | seq_len: int = 512, 218 | ): 219 | super().__init__() 220 | 221 | if torch.distributed.get_rank() == 0: 222 | train_on_percent = 5 223 | eval_on_percent = 0.25 224 | dataset = 
datasets.load_dataset( 225 | data_args.dataset_name, 226 | split="train", 227 | num_proc=48, 228 | trust_remote_code=True, 229 | ).train_test_split( 230 | train_size=(train_on_percent + eval_on_percent) / 100, 231 | seed=seed, 232 | shuffle=True, 233 | )["train"] 234 | 235 | dataset = dataset.train_test_split( 236 | train_size=train_on_percent / (train_on_percent + eval_on_percent), 237 | seed=seed, 238 | shuffle=True, 239 | )[split] 240 | 241 | encoded_text = tokenizer( 242 | list(example["text"] for example in dataset), 243 | add_special_tokens=False, 244 | padding=False, 245 | truncation=False, 246 | ).input_ids 247 | all_ids = [] 248 | assert tokenizer.bos_token_id is not None or tokenizer.eos_token_id is not None 249 | for e in encoded_text: 250 | if tokenizer.bos_token_id is not None: 251 | all_ids.append(tokenizer.bos_token_id) 252 | all_ids.extend(e) 253 | if tokenizer.eos_token_id is not None: 254 | all_ids.append(tokenizer.eos_token_id) 255 | 256 | all_ids_l = [all_ids] 257 | else: 258 | all_ids_l = [None] 259 | 260 | torch.distributed.broadcast_object_list(all_ids_l) 261 | 262 | assert all_ids_l[0] is not None 263 | self.all_input_ids = all_ids_l[0] 264 | 265 | self.seq_len = seq_len 266 | 267 | def __len__(self): 268 | return len(self.all_input_ids) // self.seq_len 269 | 270 | def __getitem__(self, i) -> dict[str, torch.Tensor]: 271 | start = i * self.seq_len 272 | seq = torch.as_tensor(self.all_input_ids[start : start + self.seq_len]) 273 | return dict(input_ids=seq, labels=seq) 274 | 275 | 276 | def make_supervised_data_module( 277 | tokenizer: transformers.PreTrainedTokenizer, 278 | data_args, 279 | seed, 280 | uses_system_prompt: bool = True, 281 | ) -> dict: 282 | """Make dataset and collator for supervised fine-tuning.""" 283 | train_dataset = SupervisedDataset( 284 | data_args, 285 | seed=seed, 286 | tokenizer=tokenizer, 287 | uses_system_prompt=uses_system_prompt, 288 | ) 289 | data_collator = DataCollatorForSupervisedDataset(tokenizer.pad_token_id, tokenizer.padding_side) 290 | return dict( 291 | train_dataset=train_dataset, 292 | eval_dataset=None, 293 | data_collator=data_collator, 294 | ) 295 | 296 | 297 | def make_pretrain_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, seed) -> dict: 298 | train_dataset = PretrainDataset(data_args, seed=seed, tokenizer=tokenizer) 299 | eval_dataset = PretrainDataset(data_args, seed=seed, tokenizer=tokenizer, split="test") 300 | os.environ["TOKENIZERS_PARALLELISM"] = "False" 301 | 302 | data_collator = DataCollatorForSupervisedDataset(tokenizer.pad_token_id, tokenizer.padding_side) 303 | 304 | return dict( 305 | train_dataset=train_dataset, 306 | eval_dataset=eval_dataset, 307 | data_collator=data_collator, 308 | ) 309 | 310 | 311 | @torch.compile(dynamic=True) 312 | def _compute_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 313 | labels = labels.flatten() 314 | logits = logits.float().flatten(0, -2) 315 | nll = torch.nn.functional.cross_entropy( 316 | logits, 317 | labels, 318 | ignore_index=IGNORE_INDEX, 319 | reduction="none", 320 | ) 321 | nll[labels == IGNORE_INDEX] = 0.0 322 | return nll 323 | 324 | 325 | def _compute_ppl(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 326 | nll = _compute_loss(logits, labels) 327 | 328 | return torch.exp(nll.mean(-1) * (nll.size(-1) / (labels != IGNORE_INDEX).count_nonzero(-1))) 329 | 330 | 331 | @dataclass 332 | class MetricReducer: 333 | _val: float | torch.Tensor = 0.0 334 | _counter: int = 0 335 | 336 | @torch.no_grad() 337 | def 
add(self, v: torch.Tensor): 338 | if v.numel() > 0: 339 | self._val = self._val + v.detach().sum() 340 | self._counter += v.numel() 341 | 342 | @property 343 | def value(self) -> float: 344 | return float(self._val) / max(self._counter, 1) 345 | 346 | def reset(self): 347 | self._val = 0 348 | self._counter = 0 349 | 350 | 351 | @dataclass 352 | class Metrics: 353 | ppl: MetricReducer = field(default_factory=MetricReducer) 354 | 355 | @torch.inference_mode() 356 | def __call__(self, eval_pred: EvalPrediction, compute_result: bool) -> dict[str, float]: 357 | logits = torch.as_tensor(eval_pred.predictions[..., :-1, :]) 358 | labels = torch.as_tensor(eval_pred.label_ids[..., 1:]) 359 | 360 | ppl = _compute_ppl(logits, labels) 361 | 362 | self.ppl.add(ppl) 363 | 364 | res = { 365 | "perplexity": self.ppl.value, 366 | } 367 | 368 | if compute_result: 369 | self.ppl.reset() 370 | 371 | return res 372 | 373 | 374 | def main(): 375 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 376 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 377 | 378 | training_args = cast(TrainingArguments, training_args) 379 | model_args = cast(ModelArguments, model_args) 380 | data_args = cast(DataArguments, data_args) 381 | 382 | if model_args.model_name in MODEL_NAME_MAP: 383 | model_args.model_name = MODEL_NAME_MAP[model_args.model_name] 384 | 385 | if data_args.dataset_name in DATA_NAME_MAP: 386 | data_args.dataset_name = DATA_NAME_MAP[data_args.dataset_name] 387 | 388 | if torch.distributed.is_initialized(): 389 | if training_args.local_rank == 0: 390 | download_hf(model_args.model_name) 391 | download_hf(data_args.dataset_name, "dataset") 392 | 393 | torch.distributed.barrier() 394 | 395 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name, use_fast=True) 396 | config = transformers.AutoConfig.from_pretrained(model_args.model_name) 397 | 398 | if config.model_type == "mistral": 399 | tokenizer.padding_side = "left" 400 | tokenizer.pad_token = "" 401 | elif config.model_type == "llama": 402 | tokenizer.pad_token = "<|reserved_special_token_0|>" 403 | elif config.model_type == "phi3": 404 | tokenizer.eos_token = "<|end|>" 405 | 406 | attn_impl = model_args.attn_impl 407 | if attn_impl is None: 408 | if config.model_type == "gemma2": 409 | attn_impl = "eager" 410 | else: 411 | attn_impl = "flash_attention_2" 412 | 413 | # This could be done instead. 
That will patch transformers code globally 414 | # cce_patch(config, model_args.cross_entropy_impl) 415 | 416 | is_finetune = "alpaca" in data_args.dataset_name 417 | 418 | if is_finetune: 419 | model = transformers.AutoModelForCausalLM.from_pretrained( 420 | model_args.model_name, 421 | attn_implementation=attn_impl, 422 | torch_dtype=torch.bfloat16, 423 | ) 424 | else: 425 | model = transformers.AutoModelForCausalLM.from_config( 426 | config, 427 | attn_implementation=attn_impl, 428 | torch_dtype=torch.bfloat16, 429 | ) 430 | 431 | device = torch.device("cuda", torch.cuda.current_device()) 432 | model = model.to(device) 433 | 434 | model = cast(transformers.PreTrainedModel, model) 435 | 436 | model = cce_patch(model, model_args.cross_entropy_impl, train_only=True) 437 | 438 | if is_finetune: 439 | data_module = make_supervised_data_module( 440 | tokenizer, 441 | data_args, 442 | training_args.seed, 443 | uses_system_prompt=config.model_type not in ("gemma2",), 444 | ) 445 | compute_metrics = None 446 | else: 447 | data_module = make_pretrain_data_module( 448 | tokenizer, 449 | data_args, 450 | training_args.seed, 451 | ) 452 | training_args.batch_eval_metrics = True 453 | compute_metrics = Metrics() 454 | 455 | os.environ["TOKENIZERS_PARALLELISM"] = "False" 456 | 457 | trainer = transformers.Trainer( 458 | model, 459 | training_args, 460 | processing_class=tokenizer, 461 | compute_metrics=compute_metrics, 462 | **data_module, 463 | ) 464 | 465 | trainer.train() 466 | 467 | if data_module.get("eval_dataset") is not None: 468 | trainer.evaluate() 469 | 470 | 471 | if __name__ == "__main__": 472 | main() 473 | -------------------------------------------------------------------------------- /training/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": 5e8, 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } 29 | --------------------------------------------------------------------------------
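For orientation, here is a minimal sketch of the linear_cross_entropy call that the tests above exercise. The shapes, dtype, and values are illustrative only, and a CUDA device is assumed since the CCE kernels require one; see the tests for the remaining arguments (bias, softcap, shift, vocab_parallel_options) and the behavior they are expected to have.

    import torch

    from cut_cross_entropy import linear_cross_entropy

    # Token embeddings (batch, time, dim) and the classifier / unembedding matrix (vocab, dim).
    e = torch.randn(4, 64, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    c = torch.randn(512, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    targets = torch.randint(0, 512, (4, 64), device="cuda")

    # Computes the same loss as cross_entropy on e @ c.T, but without materializing the logits.
    loss = linear_cross_entropy(e, c, targets, impl="cce", reduction="mean")
    loss.backward()

In training/train.py the same kernel is reached indirectly: cce_patch replaces the model's loss computation rather than calling linear_cross_entropy by hand.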