├── .gitignore ├── LICENSE ├── README.md ├── fastckpt ├── __init__.py ├── llama_flash_attn_ckpt_monkey_patch.py └── llama_flash_attn_monkey_patch.py ├── pyproject.toml └── tests ├── README.md └── test_numerical_difference.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | .venv 7 | 8 | # Log 9 | *.log 10 | *.log.* 11 | *.json 12 | !playground/deepspeed_config_s2.json 13 | !playground/deepspeed_config_s3.json 14 | 15 | # Editor 16 | .idea 17 | *.swp 18 | 19 | # Other 20 | .DS_Store 21 | wandb 22 | output 23 | checkpoints_flant5_3b 24 | 25 | # Data 26 | *.pkl 27 | *.csv 28 | tests/state_of_the_union.txt 29 | 30 | # Build 31 | build -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastCkpt: accelerate your LLM training in one line!
2 |
3 | Fast gradient checkpointing (FastCkpt) is designed to accelerate training with memory-efficient attention like FlashAttention and LightSeq. FastCkpt ships monkey patches for both rematerialization-aware checkpointing and FlashAttention, so you can apply both in only one line!
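For a concrete picture of where the one-line patch goes, here is a minimal end-to-end sketch modeled on `tests/test_numerical_difference.py` (not part of the package: the checkpoint path is a placeholder, the hidden size 4096 assumes a LLaMA-7B-class model, and a CUDA GPU with `flash-attn` installed is required):

```python
# Apply the FastCkpt patch BEFORE importing transformers.
from fastckpt.llama_flash_attn_ckpt_monkey_patch import (
    replace_hf_ckpt_with_fast_ckpt,
    clear_all_buffers_at_the_end_of_training,
)
replace_hf_ckpt_with_fast_ckpt()

import torch
import transformers

# Placeholder path -- point this at a local LLaMA checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama-7b").half().cuda()
model.model.gradient_checkpointing = True  # FastCkpt takes effect when checkpointing is enabled
model.train()

# One forward/backward step on random embeddings (4096 = LLaMA-7B hidden size).
inputs_embeds = torch.randn(1, 1024, 4096, requires_grad=True).half().cuda()
loss = model(inputs_embeds=inputs_embeds, use_cache=False).logits.mean()
loss.backward()

# Release the global flash-attention buffers once training is done.
clear_all_buffers_at_the_end_of_training()
```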
4 |
5 | Paper: https://arxiv.org/pdf/2310.03294.pdf
6 |
7 | ## News
8 | - [2023/10] FastCkpt now supports `LlamaModel` in Hugging Face Transformers!
9 |
10 | ## Install
11 | ```bash
12 | pip install fastckpt
13 | ```
14 |
15 | ## Usage
16 | FastCkpt now supports the HF training pipeline.
17 |
18 | ### Use FastCkpt and FlashAttention
19 | To use `fastckpt` with `flash_attn`, import and run `replace_hf_ckpt_with_fast_ckpt` *before* importing `transformers`:
20 | ```python
21 | # add monkey patch for fastckpt
22 | from fastckpt.llama_flash_attn_ckpt_monkey_patch import replace_hf_ckpt_with_fast_ckpt
23 | replace_hf_ckpt_with_fast_ckpt()
24 |
25 | # import transformers and other packages
26 | import transformers
27 | ...
28 | ```
29 |
30 | ### Use FlashAttention only
31 | To replace only `LlamaAttention` with `flash_attn` without changing the checkpointing strategy, import and run `replace_llama_attn_with_flash_attn`:
32 |
33 | ```python
34 | # add monkey patch for fastckpt
35 | from fastckpt.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
36 | replace_llama_attn_with_flash_attn()
37 |
38 | # import transformers and other packages
39 | import transformers
40 | ...
41 | ```
42 |
43 | If you find this repo useful, please cite:
44 | ```
45 | @article{li2023lightseq,
46 |   title={LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers},
47 |   author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P. and Gonzalez, Joseph E. and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
48 |   journal={arXiv preprint arXiv:2310.03294},
49 |   year={2023}
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/fastckpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RulinShao/FastCkpt/96fb5bb8fe0a54b27cb17ebb1f5add61011e914d/fastckpt/__init__.py
--------------------------------------------------------------------------------
/fastckpt/llama_flash_attn_ckpt_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 | from einops import rearrange
3 |
4 | import torch
5 | from torch import nn
6 | from torch.utils.checkpoint import _get_autocast_kwargs, check_backward_validity, get_device_states, set_device_states, detach_variable
7 |
8 | import transformers
9 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, BaseModelOutputWithPast
10 |
11 | from flash_attn.flash_attn_interface import _flash_attn_varlen_forward, _flash_attn_varlen_backward
12 | from flash_attn.bert_padding import unpad_input, pad_input
13 |
14 |
15 | # define a global buffer to save flash attention outputs
16 | # it's called global because it saves the outputs for all layers
17 | global_flash_attn_out_buffer = None
18 |
19 | # define a local buffer to save the gradient of the residual
20 | # it's called local because it's a temporary buffer which is overwritten from layer to layer
21 | local_res_grad_buffer = None
22 |
23 | # hooks for the gradients of residual
24 | global_hooks = []
25 |
26 | def init_flash_attn_buffers(num_layers):
27 |     # update the global buffer according to number of layers
28 |     global global_flash_attn_out_buffer
29 |     global_flash_attn_out_buffer = [None] * num_layers
30 |
31 | def clean_hook():
32 |     # Remove all registered hooks
33 |     for hook in global_hooks:
34 |         hook.remove()
35 |     # Clear the global hook list
36 |     global_hooks.clear()
37 |
38 | def clear_all_buffers_at_the_end_of_training():
39 |     # call it at the end of training
40 |     global global_flash_attn_out_buffer
41 |     global_flash_attn_out_buffer = None
42 |     global local_res_grad_buffer
43 |     local_res_grad_buffer = None
44 |     clean_hook()
45 |
46 | def save_flash_attn_out_to_global_buffer(idx, out):
47 |     global global_flash_attn_out_buffer
48 |     global_flash_attn_out_buffer[idx] = out
49 |
50 | def get_flash_attn_out_from_global_buffer(idx):
51 |     global global_flash_attn_out_buffer
52 |     return global_flash_attn_out_buffer[idx]
53 |
54 | def free_flash_attn_out_buffer(idx):
55 |     global global_flash_attn_out_buffer
56 |     global_flash_attn_out_buffer[idx] = None
57 |
58 | def write_gradient_to_flash_attn_out(idx, grad):
59 |     global global_flash_attn_out_buffer
60 |     global_flash_attn_out_buffer[idx].grad = grad
61 |
62 | def save_res_grad_hook(grad):
63 |     global local_res_grad_buffer
64 |     local_res_grad_buffer = grad
65 |
66 | def load_and_add_res_grad_hook(grad):
67 |     grad += get_res_grad_from_local_buffer()
68 |
69 | def get_res_grad_from_local_buffer():
70 |     global local_res_grad_buffer
71 |     assert local_res_grad_buffer is not None
72 |     return local_res_grad_buffer
73 |
74 | class CheckpointFunctionEndWithFlashAttention(torch.autograd.Function):
75 |     """ Avoid running the flash attention forward twice during the checkpointed backward.
76 |     args:
77 |         hidden_states, # i.e., flash attention output which is saved in global buffer.
78 |         attention_mask,
79 |         position_ids,
80 |         residual, # the gradient of residual is saved in local buffer to pass across ckpt layers.
81 |     """
82 |
83 |     @staticmethod
84 |     def forward(ctx, run_function, layer_idx, preserve_rng_state, *args):
85 |         check_backward_validity(args)
86 |         ctx.run_function = run_function
87 |         ctx.layer_idx = layer_idx
88 |         ctx.preserve_rng_state = preserve_rng_state
89 |         # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
90 |         ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
91 |         if preserve_rng_state:
92 |             ctx.fwd_cpu_state = torch.get_rng_state()
93 |             # Don't eagerly initialize the cuda context by accident.
94 |             # (If the user intends that the context is initialized later, within their
95 |             # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
96 |             # we have no way to anticipate this will happen before we run the function.)
97 |             ctx.had_cuda_in_fwd = False
98 |             if torch.cuda._initialized:
99 |                 ctx.had_cuda_in_fwd = True
100 |                 ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
101 |
102 |         # Save non-tensor inputs in ctx, keep a placeholder None for tensors
103 |         # to be filled out during the backward.
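        # (Editor's note, added for clarity:) for every layer except the first
        # (layer_idx != 0), args[0] is the flash attention output produced by the
        # previous checkpointed segment. It is deliberately not stashed in ctx;
        # the backward pass reloads it from global_flash_attn_out_buffer at index
        # layer_idx - 1, so that tensor is kept once in the shared buffer rather
        # than being saved again by every checkpoint segment.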
104 | ctx.inputs = [] 105 | ctx.tensor_indices = [] 106 | tensor_inputs = [] 107 | for i, arg in enumerate(args): 108 | if i == 0 and ctx.layer_idx != 0: 109 | # flash attention output is saved to the global buffer during forward 110 | ctx.inputs.append(None) 111 | else: 112 | if torch.is_tensor(arg): 113 | tensor_inputs.append(arg) 114 | ctx.tensor_indices.append(i) 115 | ctx.inputs.append(None) 116 | else: 117 | ctx.inputs.append(arg) 118 | 119 | with torch.no_grad(): 120 | # --- modules before flash attention --- 121 | query_states, key_states, value_states, attention_mask, residual = run_function(*args) 122 | 123 | # --- prepare for flash attention --- 124 | bsz, q_len, _ = residual.size() 125 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 126 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 127 | 128 | # unpad 129 | if attention_mask is None: 130 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 131 | max_s = q_len 132 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 133 | device=qkv.device) 134 | indices = None 135 | else: 136 | nheads = qkv.shape[-2] 137 | qkv = rearrange(qkv, 'b s three h d -> b s (three h d)') 138 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, attention_mask) 139 | qkv = rearrange(qkv, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 140 | 141 | # --- compute flash attention --- 142 | softmax_scale = qkv.shape[-1] ** (-0.5) 143 | out, _, _, _, _, softmax_lse, _, _ = _flash_attn_varlen_forward( 144 | qkv[:, 0], 145 | qkv[:, 1], 146 | qkv[:, 2], 147 | cu_q_lens, 148 | cu_q_lens, 149 | max_s, 150 | max_s, 151 | 0.0, 152 | softmax_scale, 153 | causal=True, 154 | return_softmax=False, 155 | ) 156 | 157 | # save flash attention output to global buffer 158 | save_flash_attn_out_to_global_buffer(ctx.layer_idx, out) 159 | tensor_inputs += [softmax_lse] 160 | ctx.softmax_scale = softmax_scale 161 | 162 | ctx.save_for_backward(*tensor_inputs) 163 | 164 | return out, residual, indices 165 | 166 | @staticmethod 167 | def backward(ctx, *args): 168 | if not torch.autograd._is_checkpoint_valid(): 169 | raise RuntimeError( 170 | "Checkpointing is not compatible with .grad() or when an `inputs` parameter" 171 | " is passed to .backward(). Please use .backward() and do not pass its `inputs`" 172 | " argument.") 173 | # Copy the list to avoid modifying original list. 174 | inputs = list(ctx.inputs) 175 | tensor_indices = ctx.tensor_indices 176 | tensors = ctx.saved_tensors 177 | tensors, softmax_lse = tensors[:-1], tensors[-1] 178 | 179 | # Fill in inputs with appropriate saved tensors. 180 | # Fill the flash attention output first 181 | if ctx.layer_idx > 0: 182 | # inputs[0] should be flash attention output 183 | inputs[0] = get_flash_attn_out_from_global_buffer(ctx.layer_idx-1) 184 | for i, idx in enumerate(tensor_indices): 185 | inputs[idx] = tensors[i] 186 | 187 | # Stash the surrounding rng state, and mimic the state that was 188 | # present at this time during forward. Restore the surrounding state 189 | # when we're done. 
190 | rng_devices = [] 191 | if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: 192 | rng_devices = ctx.fwd_gpu_devices 193 | with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): 194 | if ctx.preserve_rng_state: 195 | torch.set_rng_state(ctx.fwd_cpu_state) 196 | if ctx.had_cuda_in_fwd: 197 | set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) 198 | detached_inputs = detach_variable(tuple(inputs)) 199 | with torch.enable_grad(), \ 200 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ 201 | torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): 202 | # Stop recomputation before flash attention 203 | # It is unecessary to run recomputation for flash attn 204 | query_states, key_states, value_states, attention_mask, residual = ctx.run_function(*detached_inputs) 205 | 206 | # --- prepare for flash attention --- 207 | bsz, q_len, _ = residual.size() 208 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 209 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 210 | 211 | # unpad 212 | if attention_mask is None: 213 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 214 | max_s = q_len 215 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 216 | device=qkv.device) 217 | indices = None 218 | else: 219 | nheads = qkv.shape[-2] 220 | qkv = rearrange(qkv, 'b s three h d -> b s (three h d)') 221 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, attention_mask) 222 | qkv = rearrange(qkv, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 223 | 224 | # run backward() with only tensor that requires grad 225 | # run flash attention backward first: 226 | # get 'dout' from auto_grad inputs 227 | # get 'out' from global buffer 228 | # get 'qkv' from the recomputed tensors 229 | q = qkv[:, 0] 230 | qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) 231 | dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) 232 | out = get_flash_attn_out_from_global_buffer(ctx.layer_idx) 233 | 234 | # --- flash attention backward --- 235 | dout = args[0] 236 | _flash_attn_varlen_backward( 237 | dout, 238 | qkv[:, 0], 239 | qkv[:, 1], 240 | qkv[:, 2], 241 | out, 242 | softmax_lse, 243 | dqkv[:, 0], 244 | dqkv[:, 1], 245 | dqkv[:, 2], 246 | cu_q_lens, 247 | cu_q_lens, 248 | max_s, 249 | max_s, 250 | 0.0, 251 | ctx.softmax_scale, 252 | True, 253 | rng_state=None, 254 | ) 255 | dqkv = dqkv[..., : dout.shape[-1]] 256 | 257 | # run backward for the part before flash attention 258 | qkv.backward(dqkv) 259 | 260 | grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None 261 | for inp in detached_inputs) 262 | 263 | # write flash attention output gradients to buffer 264 | if ctx.layer_idx > 0: 265 | write_gradient_to_flash_attn_out(ctx.layer_idx-1, detached_inputs[0].grad) 266 | 267 | return (None, None, None) + grads 268 | 269 | 270 | def checkpoint_end_with_flash_attention(function, layer_idx, *args, use_reentrant: bool = True, **kwargs): 271 | # Hack to mix *args with **kwargs in a python 2.7-compliant way 272 | preserve = kwargs.pop('preserve_rng_state', True) 273 | if kwargs and use_reentrant: 274 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 275 | 276 | return CheckpointFunctionEndWithFlashAttention.apply(function, layer_idx, preserve, *args) 277 | 278 | 279 | class CheckpointFunctionLastModule(torch.autograd.Function): 280 | """ 281 | for the last ffn layer after flash attention, modifications include: 282 | write the gradients wrt flash attention output and 
residual to the global buffer. 283 | """ 284 | 285 | @staticmethod 286 | def forward(ctx, run_function, preserve_rng_state, *args): 287 | check_backward_validity(args) 288 | ctx.run_function = run_function 289 | ctx.preserve_rng_state = preserve_rng_state 290 | # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. 291 | ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() 292 | if preserve_rng_state: 293 | ctx.fwd_cpu_state = torch.get_rng_state() 294 | # Don't eagerly initialize the cuda context by accident. 295 | # (If the user intends that the context is initialized later, within their 296 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately, 297 | # we have no way to anticipate this will happen before we run the function.) 298 | ctx.had_cuda_in_fwd = False 299 | if torch.cuda._initialized: 300 | ctx.had_cuda_in_fwd = True 301 | ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) 302 | 303 | # Save non-tensor inputs in ctx, keep a placeholder None for tensors 304 | # to be filled out during the backward. 305 | ctx.inputs = [] 306 | ctx.tensor_indices = [] 307 | tensor_inputs = [] 308 | 309 | assert torch.is_tensor(args[0]), "assuming the first tensor is the flash attention output" 310 | for i, arg in enumerate(args): 311 | if torch.is_tensor(arg) and i == 0: 312 | # flash attn output has been saved to global buffer 313 | ctx.inputs.append(None) 314 | elif torch.is_tensor(arg): 315 | tensor_inputs.append(arg) 316 | ctx.tensor_indices.append(i) 317 | ctx.inputs.append(None) 318 | else: 319 | ctx.inputs.append(arg) 320 | 321 | ctx.save_for_backward(*tensor_inputs) 322 | 323 | with torch.no_grad(): 324 | outputs = run_function(*args) 325 | return outputs 326 | 327 | @staticmethod 328 | def backward(ctx, *args): 329 | if not torch.autograd._is_checkpoint_valid(): 330 | raise RuntimeError( 331 | "Checkpointing is not compatible with .grad() or when an `inputs` parameter" 332 | " is passed to .backward(). Please use .backward() and do not pass its `inputs`" 333 | " argument.") 334 | # Copy the list to avoid modifying original list. 335 | inputs = list(ctx.inputs) 336 | tensor_indices = ctx.tensor_indices 337 | tensors = ctx.saved_tensors 338 | 339 | # Fill in inputs with appropriate saved tensors. 340 | # Fill the flash attention output first 341 | # inputs[0] should be flash attention output 342 | inputs[0] = get_flash_attn_out_from_global_buffer(-1) 343 | for i, idx in enumerate(tensor_indices): 344 | inputs[idx] = tensors[i] 345 | 346 | # Stash the surrounding rng state, and mimic the state that was 347 | # present at this time during forward. Restore the surrounding state 348 | # when we're done. 
349 | rng_devices = [] 350 | if ctx.preserve_rng_state and ctx.had_cuda_in_fwd: 351 | rng_devices = ctx.fwd_gpu_devices 352 | with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): 353 | if ctx.preserve_rng_state: 354 | torch.set_rng_state(ctx.fwd_cpu_state) 355 | if ctx.had_cuda_in_fwd: 356 | set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) 357 | detached_inputs = detach_variable(tuple(inputs)) 358 | with torch.enable_grad(), \ 359 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \ 360 | torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): 361 | outputs = ctx.run_function(*detached_inputs) 362 | 363 | if isinstance(outputs, torch.Tensor): 364 | outputs = (outputs,) 365 | 366 | # run backward() with only tensor that requires grad 367 | outputs_with_grad = [] 368 | args_with_grad = [] 369 | for i in range(len(outputs)): 370 | if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: 371 | outputs_with_grad.append(outputs[i]) 372 | args_with_grad.append(args[i]) 373 | if len(outputs_with_grad) == 0: 374 | raise RuntimeError( 375 | "none of output has requires_grad=True," 376 | " this checkpoint() is not necessary") 377 | torch.autograd.backward(outputs_with_grad, args_with_grad) 378 | grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None 379 | for inp in detached_inputs) 380 | 381 | # write flash attention output gradients to buffer 382 | write_gradient_to_flash_attn_out(-1, detached_inputs[0].grad) 383 | 384 | return (None, None) + grads 385 | 386 | def checkpoint_last_module(function, *args, use_reentrant: bool = True, **kwargs): 387 | preserve = kwargs.pop('preserve_rng_state', True) 388 | if kwargs and use_reentrant: 389 | raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) 390 | 391 | return CheckpointFunctionLastModule.apply(function, preserve, *args) 392 | 393 | 394 | def llama_layer_forward( 395 | self, 396 | hidden_states: torch.Tensor, 397 | attention_mask: Optional[torch.Tensor] = None, 398 | position_ids: Optional[torch.LongTensor] = None, 399 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 400 | output_attentions: Optional[bool] = False, 401 | use_cache: Optional[bool] = False, 402 | compute_attn_only: Optional[bool] = False, 403 | compute_ffn_only: Optional[bool] = False, 404 | residual: Optional[bool] = None, 405 | indices: Optional[bool] = None, 406 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 407 | """ 408 | Args: 409 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 410 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 411 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 412 | output_attentions (`bool`, *optional*): 413 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 414 | returned tensors for more detail. 415 | use_cache (`bool`, *optional*): 416 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 417 | (see `past_key_values`). 
418 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 419 | """ 420 | assert compute_ffn_only or compute_attn_only 421 | 422 | if compute_attn_only: 423 | residual = hidden_states 424 | 425 | if residual.requires_grad: 426 | # register a hook to add the gradient of residual 427 | # from next checkpoint layer when doing recomputation 428 | hook = residual.register_hook(load_and_add_res_grad_hook) 429 | global_hooks.append(hook) 430 | 431 | hidden_states = self.input_layernorm(hidden_states) 432 | 433 | # Flash Attention 434 | bsz, q_len, _ = hidden_states.size() 435 | query_states = self.self_attn.q_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 436 | key_states = self.self_attn.k_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 437 | value_states = self.self_attn.v_proj(hidden_states).view(bsz, q_len, self.self_attn.num_heads, self.self_attn.head_dim).transpose(1, 2) 438 | 439 | kv_seq_len = key_states.shape[-2] 440 | assert past_key_value is None, "past_key_value is not supported by fastckpt" 441 | 442 | cos, sin = self.self_attn.rotary_emb(value_states, seq_len=kv_seq_len) 443 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 444 | # [bsz, nh, t, hd] 445 | assert not output_attentions, "output_attentions is not supported by fastckpt" 446 | assert not use_cache, "use_cache is not supported by fastckpt" 447 | key_padding_mask = attention_mask 448 | 449 | return query_states, key_states, value_states, attention_mask, residual 450 | 451 | elif compute_ffn_only: 452 | # pad 453 | bsz, q_len, _ = residual.size() 454 | if attention_mask is None: 455 | hidden_states = rearrange(hidden_states, '(b s) h d -> b s (h d)', b=bsz) 456 | else: 457 | assert indices is not None 458 | hidden_states = pad_input(rearrange(hidden_states, 'nnz h d -> nnz (h d)'), indices, bsz, q_len) 459 | 460 | hidden_states = self.self_attn.o_proj(hidden_states) 461 | 462 | # Need to add residual here to make sure checkpoint is right after flash attention 463 | if residual.requires_grad: 464 | # save the gradient of residual to the local buffer 465 | # collect the hooks which should be removed after backward to avoid memory leak 466 | hook = residual.register_hook(save_res_grad_hook) 467 | global_hooks.append(hook) 468 | 469 | hidden_states = residual + hidden_states 470 | 471 | # Fully Connected 472 | 473 | residual = hidden_states 474 | hidden_states = self.post_attention_layernorm(hidden_states) 475 | hidden_states = self.mlp(hidden_states) 476 | hidden_states = residual + hidden_states 477 | 478 | outputs = (hidden_states,) 479 | 480 | else: 481 | raise AttributeError 482 | 483 | return outputs 484 | 485 | 486 | # monkey patch for LlamaDecoderLayer 487 | def forward( 488 | self, 489 | input_ids: torch.LongTensor = None, 490 | attention_mask: Optional[torch.Tensor] = None, 491 | position_ids: Optional[torch.LongTensor] = None, 492 | past_key_values: Optional[List[torch.FloatTensor]] = None, 493 | inputs_embeds: Optional[torch.FloatTensor] = None, 494 | use_cache: Optional[bool] = None, 495 | output_attentions: Optional[bool] = None, 496 | output_hidden_states: Optional[bool] = None, 497 | return_dict: Optional[bool] = None, 498 | ): 499 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 500 | output_hidden_states = ( 501 | output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states 502 | ) 503 | use_cache = use_cache if use_cache is not None else self.config.use_cache 504 | 505 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 506 | 507 | # retrieve input_ids and inputs_embeds 508 | if input_ids is not None and inputs_embeds is not None: 509 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 510 | elif input_ids is not None: 511 | batch_size, seq_length = input_ids.shape 512 | elif inputs_embeds is not None: 513 | batch_size, seq_length, _ = inputs_embeds.shape 514 | else: 515 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 516 | 517 | seq_length_with_past = seq_length 518 | past_key_values_length = 0 519 | 520 | if past_key_values is not None: 521 | past_key_values_length = past_key_values[0][0].shape[2] 522 | seq_length_with_past = seq_length_with_past + past_key_values_length 523 | 524 | if position_ids is None: 525 | device = input_ids.device if input_ids is not None else inputs_embeds.device 526 | position_ids = torch.arange( 527 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 528 | ) 529 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 530 | else: 531 | position_ids = position_ids.view(-1, seq_length).long() 532 | 533 | if inputs_embeds is None: 534 | inputs_embeds = self.embed_tokens(input_ids) 535 | 536 | hidden_states = inputs_embeds 537 | 538 | if self.gradient_checkpointing and self.training: 539 | try: 540 | logger.warning_once( 541 | "***** Using fast gradient checkpointing... *****" 542 | ) 543 | except: 544 | pass 545 | # initialize the global buffer 546 | init_flash_attn_buffers(len(self.layers)) 547 | 548 | if use_cache: 549 | try: 550 | logger.warning_once( 551 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
552 | ) 553 | except: 554 | pass 555 | use_cache = False 556 | 557 | # decoder layers 558 | all_hidden_states = () if output_hidden_states else None 559 | all_self_attns = () if output_attentions else None 560 | next_decoder_cache = () if use_cache else None 561 | 562 | # apply flash-attention friendly gradient checkpointing 563 | if self.gradient_checkpointing and self.training: 564 | for idx in range(len(self.layers) + 1): 565 | if output_hidden_states: 566 | all_hidden_states += (hidden_states,) 567 | 568 | past_key_value = past_key_values[idx] if past_key_values is not None else None 569 | 570 | # HF gradient checkpointing checkpoints (hidden_states, attention_mask, position_ids) 571 | # We only checkpoints (hidden_states) as the others are shared across layers 572 | 573 | def forward_first_attn_module(module): 574 | def custom_forward(*inputs): 575 | # hidden_states, attention_mask, position_ids, _ = inputs 576 | hidden_states, _, _ = inputs 577 | # None for past_key_value 578 | return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) 579 | return custom_forward 580 | 581 | def forward_ffn_attn_layer(module1, module2): 582 | def custom_forward(*inputs): 583 | hidden_states, residual, indices = inputs 584 | # None for past_key_value 585 | layer_outputs = module1(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual, indices=indices) 586 | hidden_states = layer_outputs[0] 587 | return module2(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_attn_only=True) 588 | return custom_forward 589 | 590 | def forward_last_ffn_module(module): 591 | def custom_forward(*inputs): 592 | hidden_states, residual, indices = inputs 593 | # None for past_key_value 594 | return module(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, compute_ffn_only=True, residual=residual, indices=indices) 595 | return custom_forward 596 | 597 | if idx == 0: 598 | layer_outputs = checkpoint_end_with_flash_attention( 599 | forward_first_attn_module(self.layers[0]), 600 | idx, 601 | hidden_states, 602 | None, 603 | None, 604 | ) 605 | hidden_states, residual, indices = layer_outputs 606 | elif idx == len(self.layers): 607 | layer_outputs = checkpoint_last_module( 608 | forward_last_ffn_module(self.layers[-1]), 609 | hidden_states, 610 | residual, 611 | indices, 612 | ) 613 | hidden_states = layer_outputs[0] 614 | else: 615 | layer_outputs = checkpoint_end_with_flash_attention( 616 | forward_ffn_attn_layer(self.layers[idx-1], self.layers[idx]), 617 | idx, 618 | hidden_states, 619 | residual, 620 | indices, 621 | ) 622 | hidden_states, residual, indices = layer_outputs 623 | 624 | if use_cache: 625 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 626 | 627 | if output_attentions: 628 | all_self_attns += (layer_outputs[1],) 629 | else: 630 | for idx, decoder_layer in enumerate(self.layers): 631 | if output_hidden_states: 632 | all_hidden_states += (hidden_states,) 633 | 634 | past_key_value = past_key_values[idx] if past_key_values is not None else None 635 | 636 | layer_outputs = decoder_layer( 637 | hidden_states, 638 | attention_mask=attention_mask, 639 | position_ids=position_ids, 640 | past_key_value=past_key_value, 641 | output_attentions=output_attentions, 642 | use_cache=use_cache, 643 | ) 644 | 645 | hidden_states = layer_outputs[0] 646 | 647 | if use_cache: 648 | next_decoder_cache += (layer_outputs[2 
if output_attentions else 1],) 649 | 650 | if output_attentions: 651 | all_self_attns += (layer_outputs[1],) 652 | 653 | hidden_states = self.norm(hidden_states) 654 | 655 | # add hidden states from the last decoder layer 656 | if output_hidden_states: 657 | all_hidden_states += (hidden_states,) 658 | 659 | next_cache = next_decoder_cache if use_cache else None 660 | if not return_dict: 661 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 662 | return BaseModelOutputWithPast( 663 | last_hidden_state=hidden_states, 664 | past_key_values=next_cache, 665 | hidden_states=all_hidden_states, 666 | attentions=all_self_attns, 667 | ) 668 | 669 | 670 | def replace_hf_ckpt_with_fast_ckpt(): 671 | transformers.models.llama.modeling_llama.LlamaModel.forward = forward 672 | transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = llama_layer_forward -------------------------------------------------------------------------------- /fastckpt/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | #from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.Tensor] = None, 20 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 21 | output_attentions: bool = False, 22 | use_cache: bool = False, 23 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 24 | Optional[Tuple[torch.Tensor]]]: 25 | """Input shape: Batch x Time x Channel 26 | 27 | attention_mask: [bsz, q_len] 28 | """ 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 33 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 34 | # [bsz, q_len, nh, hd] 35 | # [bsz, nh, q_len, hd] 36 | 37 | kv_seq_len = key_states.shape[-2] 38 | assert past_key_value is None, "past_key_value is not supported" 39 | 40 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 41 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 42 | # [bsz, nh, t, hd] 43 | assert not output_attentions, "output_attentions is not supported" 44 | assert not use_cache, "use_cache is not supported" 45 | 46 | # Flash attention codes from 47 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 48 | 49 | # transform the data into the format required by flash attention 50 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 51 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 52 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 53 | # the attention_mask should be the same as the key_padding_mask 54 | key_padding_mask = attention_mask 55 | 56 | 57 | if key_padding_mask is None: 58 | qkv = 
rearrange(qkv, 'b s ... -> (b s) ...') 59 | max_s = q_len 60 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 61 | device=qkv.device) 62 | #torch.cuda.synchronize() 63 | #torch.cuda.empty_cache() 64 | #torch.cuda.reset_peak_memory_stats() 65 | #_max_memory_start = torch.cuda.max_memory_allocated() 66 | output = flash_attn_varlen_qkvpacked_func(# flash_attn_unpadded_qkvpacked_func( 67 | qkv, cu_q_lens, max_s, 0.0, 68 | softmax_scale=None, causal=True 69 | ) 70 | #torch.cuda.synchronize() 71 | #print(f"flash attn peak:{(torch.cuda.max_memory_allocated() - _max_memory_start) / 2 ** 20}") 72 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz) 73 | else: 74 | nheads = qkv.shape[-2] 75 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 76 | 77 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 78 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 79 | output_unpad = flash_attn_varlen_qkvpacked_func( #flash_attn_unpadded_qkvpacked_func( 80 | x_unpad, cu_q_lens, max_s, 0.0, 81 | softmax_scale=None, causal=True 82 | ) 83 | #torch.cuda.synchronize() 84 | #print(f"flash attn peak:{(torch.cuda.max_memory_allocated() - _max_memory_start) / 2 ** 20}") 85 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 86 | indices, bsz, q_len), 87 | 'b s (h d) -> b s h d', h=nheads) 88 | return self.o_proj(rearrange(output, 89 | 'b s h d -> b s (h d)')), None, None 90 | 91 | 92 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 93 | # requires the attention mask to be the same as the key_padding_mask 94 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 95 | inputs_embeds, past_key_values_length): 96 | # [bsz, seq_len] 97 | return attention_mask 98 | 99 | 100 | def replace_llama_attn_with_flash_attn(): 101 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 102 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 103 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "fastckpt" 7 | version = "0.0.4" 8 | description = "A fast gradient checkpointing strategy for training with memory-efficient attention (e.g., FlashAttention)." 
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 |     "Programming Language :: Python :: 3",
13 |     "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 |     "aiohttp", "fastapi", "httpx", "markdown2[all]", "nh3", "numpy",
17 |     "prompt_toolkit>=3.0.0", "pydantic<2,>=1", "requests", "rich>=10.0.0",
18 |     "shortuuid", "tiktoken", "uvicorn",
19 | ]
20 |
21 | [project.optional-dependencies]
22 | model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
23 | webui = ["gradio"]
24 | train = ["einops", "flash-attn>=2.0", "wandb"]
25 | llm_judge = ["openai", "anthropic>=0.3", "ray"]
26 | dev = ["black==23.3.0", "pylint==2.8.2"]
27 |
28 | [project.urls]
29 | "Homepage" = "https://github.com/RulinShao/FastCkpt"
30 | "Bug Tracker" = "https://github.com/RulinShao/FastCkpt/issues"
31 |
32 | [tool.setuptools.packages.find]
33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
34 |
35 | [tool.wheel]
36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Tests
2 |
3 | ## Test numerical difference
4 | We provide a numerical-difference test in `test_numerical_difference.py`. The test verifies that there is no numerical difference between HF checkpointing + FlashAttention and FastCkpt + FlashAttention.
--------------------------------------------------------------------------------
/tests/test_numerical_difference.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 |
4 | def test_ckpt(test_hf_grad_ckpt=False, sequence_length=1024, batch_size=1, repeat=True):
5 |     # Need to call this before importing transformers.
6 | if test_hf_grad_ckpt: 7 | from fastckpt.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 8 | replace_llama_attn_with_flash_attn() 9 | else: 10 | from fastckpt.llama_flash_attn_ckpt_monkey_patch import replace_hf_ckpt_with_fast_ckpt, clear_all_buffers_at_the_end_of_training 11 | replace_hf_ckpt_with_fast_ckpt() 12 | 13 | import transformers 14 | 15 | model_name_or_path = "Llama-2-7b-chat-hf" 16 | model = transformers.AutoModelForCausalLM.from_pretrained( 17 | model_name_or_path, 18 | ).half().cuda() 19 | 20 | 21 | model.model.gradient_checkpointing = True 22 | model.train() 23 | warmup_steps = 20 24 | total_steps = 60 25 | 26 | torch.manual_seed(42) 27 | inputs_embeds = torch.randn((batch_size, sequence_length, 4096), requires_grad=True).half().cuda() 28 | torch.cuda.synchronize() 29 | time_per_iter = [] 30 | for i in range(total_steps): 31 | start = time.time() 32 | outputs = model(inputs_embeds=inputs_embeds, use_cache=False) 33 | loss = torch.mean(outputs.logits) 34 | loss.backward() 35 | torch.cuda.synchronize() 36 | end = time.time() 37 | if i >= warmup_steps: 38 | time_per_iter.append(end-start) 39 | if not test_hf_grad_ckpt: 40 | clear_all_buffers_at_the_end_of_training() 41 | if not repeat: 42 | break 43 | 44 | avg_time = sum(time_per_iter) / len(time_per_iter) if repeat else end-start 45 | print(f"Avg forward + backward spent {avg_time}s.") 46 | out = [outputs.logits.detach()] 47 | out += [model.model.layers[1].mlp.up_proj.weight.grad] 48 | out += [model.model.layers[1].self_attn.q_proj.weight.grad] 49 | out += [model.model.layers[0].mlp.up_proj.weight.grad] 50 | out += [model.model.layers[0].self_attn.q_proj.weight.grad] 51 | return out 52 | 53 | 54 | def test_numerical_difference(): 55 | my_ckpt_out = test_ckpt(False, repeat=False) 56 | hf_ckpt_out = test_ckpt(True, repeat=False) 57 | for i, (my_out, hf_out) in enumerate(zip(my_ckpt_out, hf_ckpt_out)): 58 | assert torch.allclose(my_out, hf_out) 59 | print(f"Passed {i}-th check!") 60 | 61 | 62 | if __name__ == '__main__': 63 | test_numerical_difference() --------------------------------------------------------------------------------
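A practical note on the test above (an observation from reading the code, not something the repo documents): `test_ckpt()` hard-codes `model_name_or_path = "Llama-2-7b-chat-hf"`, so the weights must resolve from the working directory, and a CUDA GPU with `flash-attn` installed is assumed. Because both patch functions mutate module-level attributes of `transformers`, it may also be safest to benchmark the two configurations in separate processes, for example:

```python
# Hypothetical timing driver (not part of the repo); run it from the tests/ directory.
# Each configuration gets a fresh interpreter so the monkey patches cannot accumulate
# across runs. Assumes the hard-coded Llama-2-7b-chat-hf weights resolve from the
# working directory and a CUDA GPU with flash-attn is available.
import subprocess
import sys

for flag in ("False", "True"):  # False = FastCkpt, True = HF checkpointing + FlashAttention
    code = (
        "from test_numerical_difference import test_ckpt; "
        f"test_ckpt(test_hf_grad_ckpt={flag})"
    )
    subprocess.run([sys.executable, "-c", code], check=True)
```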